activation_quant_fusion.py 6.34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from abc import ABC, abstractmethod

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)
from torch._ops import OpOverload

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.platforms import current_platform

from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
28
29
30

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

SILU_MUL_OP = torch.ops._C.silu_and_mul.default

# Quant scheme -> fused silu_and_mul+quant custom op. Each
# ActivationQuantPattern looks up its replacement op in this table.
FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,
}

# The NVFP4 fused op is registered only on CUDA platforms whose
# torch.ops._C namespace exposes silu_and_mul_nvfp4_quant.
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
    torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
    FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default


class ActivationQuantPattern(ABC):
    """Abstract base for SiluAndMul + quantization fusion patterns.

    For the given ``QuantKey`` the base class resolves the unfused quant op
    (``QUANT_OP``) from ``QUANT_OPS`` and the fused activation+quant op
    (``FUSED_OP``) from ``FUSED_OPS``. It also creates the shared SiluAndMul
    matcher. Subclasses implement :meth:`register`. Not meant to be
    instantiated directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ):
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        # Unfused quantization op that the pattern side matches against.
        assert quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[quant_key]

        # Fused activation+quant op that the replacement side emits.
        assert quant_key in FUSED_OPS, (
            f"unsupported fusion scheme {quant_key}"
        )
        self.FUSED_OP = FUSED_OPS[quant_key]

        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def empty_quant(self, *args, **kwargs):
        """Allocate an uninitialized tensor in this pattern's quant dtype.

        Defaults to ``dtype=self.quant_dtype`` and ``device="cuda"``. Any
        keyword passed explicitly overrides those defaults. Positional
        arguments (e.g. the shape) are forwarded to :func:`torch.empty`.
        """
        options = {"dtype": self.quant_dtype, "device": "cuda"}
        options.update(kwargs)
        return torch.empty(*args, **options)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass):
        """Register this fusion's pattern/replacement pair on ``pm_pass``."""
        raise NotImplementedError


class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fuses SiluAndMul followed by static FP8 quantization
    (``kFp8StaticTensorSym``) into the ``silu_and_mul_quant`` custom op.
    """

    def __init__(self):
        super().__init__(kFp8StaticTensorSym)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        """Trace and register the SiluMul+FP8 pattern/replacement pair."""

        # NOTE: parameter names of `pattern`/`replacement` are part of the
        # pattern-matcher contract; keep them in sync.
        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused graph: activation, then a separate FP8 quant. Only the
            # first output of the quant matcher is kept.
            activated = self.silu_and_mul_matcher(input)
            return self.quant_matcher(activated, scale)[0]

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ):
            # The fused output buffer has the input's shape with the last
            # dimension halved.
            half = input.shape[-1] // 2
            result = torch.empty(
                input.shape[:-1] + (half,),
                device=input.device,
                dtype=self.quant_dtype,
            )
            # Keyword order is kept as-is; auto_functionalized kwargs are
            # compared when matching traced graphs.
            fused = auto_functionalized(
                self.FUSED_OP, result=result, input=input, scale=scale
            )
            # Presumably the functionalized `result` buffer (index 0 being the
            # op's own return value) -- confirm against auto_functionalized.
            return fused[1]

        example_inputs = [
            *self.silu_and_mul_matcher.inputs(),  # input
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the return value of this call is discarded; it looks
        # like a trace-time sanity check of the pattern. Confirm whether it is
        # still needed.
        pattern(*example_inputs)

        register_replacement(
            pattern, replacement, example_inputs, fwd_only, pm_pass
        )


class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fuses SiluAndMul followed by NVFP4 quantization into the
    ``silu_and_mul_nvfp4_quant`` custom op.

    Construction requires ``kNvfp4Quant`` to be present in ``FUSED_OPS``,
    which only happens when ``silu_and_mul_nvfp4_quant_supported`` is true.
    """

    def __init__(self):
        super().__init__(kNvfp4Quant)

    def register(self, pm_pass: PatternMatcherPass):
        """Trace and register the SiluMul+NVFP4 pattern/replacement pair."""

        def pattern(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused graph: activation followed by the standalone NVFP4
            # quant op. Keyword order is kept as-is for matching.
            activated = self.silu_and_mul_matcher(input)
            quantized = auto_functionalized(
                self.QUANT_OP,
                output=result,
                input=activated,
                output_scale=output_scale,
                input_scale=scale,
            )
            # Presumably the functionalized `output` and `output_scale`
            # buffers -- confirm against auto_functionalized's convention.
            return quantized[1], quantized[2]

        def replacement(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ):
            fused = auto_functionalized(
                self.FUSED_OP,
                result=result,
                result_block_scale=output_scale,
                input=input,
                input_global_scale=scale,
            )
            return fused[1], fused[2]

        # Example tensors used only to trace pattern/replacement; assumed to
        # matter only for dtype/rank, not exact size -- TODO confirm.
        example_inputs = [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1),  # scale
        ]

        register_replacement(
            pattern, replacement, example_inputs, fwd_only, pm_pass
        )


class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    Replaces SiluAndMul + quantization sequences with fused custom ops,
    using the torch pattern matcher to find and rewrite each registered
    pattern.

    Patterns can only be registered once, so this pass is a singleton.
    A future PyTorch version is expected to lift that restriction:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass"
        )

        # The FP8 fusion is always registered; NVFP4 only when
        # silu_and_mul_nvfp4_quant_supported is true. Order: FP8, then NVFP4.
        pattern_types: list[type[ActivationQuantPattern]] = [
            SiluMulFp8StaticQuantPattern
        ]
        if silu_and_mul_nvfp4_quant_supported:
            pattern_types.append(SiluMulNvfp4QuantPattern)
        for pattern_type in pattern_types:
            pattern_type().register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Run the registered fusions over ``graph``; record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self):
        """Identifier hashed from the source of this pass and its patterns."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            SiluMulFp8StaticQuantPattern,
            SiluMulNvfp4QuantPattern,
        )