"lib/kvbm-kernels/tests/kernel_roundtrip.rs" did not exist on "673822ea931352cebe5d6e49a9b66f66a8f1d272"
act_quant_fusion.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Any
6

7
8
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
11
12
13
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)
14
from torch._ops import OpOverload
15
16
17

from vllm.config import VllmConfig
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization.utils.quant_utils import (
19
    QuantKey,
20
21
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
22
    kFp8StaticTensorSym,
23
    kNvfp4Dynamic,
24
)
25
from vllm.platforms import current_platform
26

27
28
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
29
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
30
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
31
32
33

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

SILU_MUL_OP = torch.ops._C.silu_and_mul.default

# Quantization scheme -> fused "SiLU-and-mul + quantize" custom op that the
# pattern classes in this module substitute for the unfused op sequence.
FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,
}

# The NVFP4 fused kernel is only used on CUDA when the compiled `_C`
# extension actually exposes it, so probe for it instead of assuming it.
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
    torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
    FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default

if current_platform.is_cuda():
    # One per-block kernel serves both FP8 dynamic group sizes; the 128 key is
    # inserted before the 64 key, as in the original registration order.
    FUSED_OPS.update(
        dict.fromkeys(
            (kFp8Dynamic128Sym, kFp8Dynamic64Sym),
            torch.ops._C.silu_and_mul_per_block_quant.default,
        )
    )


class ActivationQuantPattern(ABC):
    """
    Abstract base for activation (SiLU-and-mul) + quantization fusions.

    A subclass describes one unfused pattern together with its fused
    replacement and implements ``register``. Not meant to be used directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ) -> None:
        """
        Look up the unfused quant op and the fused op for ``quant_key``.

        Raises:
            AssertionError: if ``quant_key`` is missing from ``QUANT_OPS``
                (no unfused quant op) or from ``FUSED_OPS`` (no fused op).
        """
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert quant_key in FUSED_OPS, f"unsupported fusion scheme {quant_key}"
        self.FUSED_OP = FUSED_OPS[quant_key]

        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Allocate an uninitialized tensor in this pattern's quantized dtype.

        ``dtype`` defaults to ``self.quant_dtype`` and ``device`` to
        ``"cuda"``; either default can be overridden through ``kwargs``.
        """
        options: dict[str, Any] = {"dtype": self.quant_dtype, "device": "cuda"}
        options.update(kwargs)
        return torch.empty(*args, **options)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern and its replacement with ``pm_pass``."""
        raise NotImplementedError


class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiLU-and-mul followed by static FP8 quantization.

    The unfused sequence is rewritten to the fused op registered in
    ``FUSED_OPS`` under ``kFp8StaticTensorSym``.
    """

    def __init__(self) -> None:
        super().__init__(kFp8StaticTensorSym)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Trace inputs: the SiLU-and-mul input(s) followed by the quant scale."""
        quant_scale = self.quant_matcher.inputs()[1]
        trace_inputs = list(self.silu_and_mul_matcher.inputs())
        trace_inputs.append(quant_scale)
        return trace_inputs

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused -> fused rewrite with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused graph: SiLU-and-mul, then FP8 quantization with `scale`.
            activated = self.silu_and_mul_matcher(input)
            quantized = self.quant_matcher(activated, scale)
            return quantized[0]

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # SiLU-and-mul halves the last dimension, so the fused output
            # buffer is (..., input.shape[-1] // 2) in the quantized dtype.
            half = input.shape[-1] // 2
            out_buffer = torch.empty(
                input.shape[:-1] + (half,),
                device=input.device,
                dtype=self.quant_dtype,
            )
            fused = auto_functionalized(
                self.FUSED_OP, result=out_buffer, input=input, scale=scale
            )
            # auto_functionalized yields the op's return first, then the
            # mutated args; element 1 is the filled `result` buffer.
            return fused[1]

        example_inputs = self.get_inputs()
        # NOTE(review): the pattern is run once eagerly before registration;
        # the reason is not visible here (possibly matcher warm-up) — kept.
        pattern(*example_inputs)
        register_replacement(pattern, replacement, example_inputs, fwd_only, pm_pass)


class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiLU-and-mul followed by dynamic NVFP4 quantization.

    Requires ``FUSED_OPS`` to contain ``kNvfp4Dynamic``, which only happens
    when ``silu_and_mul_nvfp4_quant_supported`` is true.
    """

    def __init__(self) -> None:
        super().__init__(kNvfp4Dynamic)

    def get_inputs(self) -> list[torch.Tensor]:
        """Placeholder tensors used to trace the pattern.

        Returned in pattern-argument order ``[result, output_scale, input,
        scale]``: a (5, 32) buffer in the quantized dtype, a (128, 4) int32
        scale buffer, a (5, 64) bf16 activation input and a (1, 1) fp32 scale.
        """
        return [
            self.empty_quant(5, 32),
            empty_i32(128, 4),
            empty_bf16(5, 64),
            empty_fp32(1, 1),
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused -> fused NVFP4 rewrite with ``pm_pass``."""

        def pattern(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused graph: SiLU-and-mul, then the NVFP4 quant op writing
            # into `result` / `output_scale` with swizzled scale layout.
            activated = self.silu_and_mul_matcher(input)
            quant_outputs = auto_functionalized(
                self.QUANT_OP,
                input=activated,
                input_scale=scale,
                is_sf_swizzled_layout=True,
                output=result,
                output_scale=output_scale,
            )
            # presumably the mutated quantized output and its block scales
            return quant_outputs[1], quant_outputs[2]

        def replacement(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Fused op consumes the raw (pre-activation) input directly.
            fused_outputs = auto_functionalized(
                self.FUSED_OP,
                result=result,
                result_block_scale=output_scale,
                input=input,
                input_global_scale=scale,
            )
            return fused_outputs[1], fused_outputs[2]

        trace_inputs = self.get_inputs()
        register_replacement(pattern, replacement, trace_inputs, fwd_only, pm_pass)


class SiluMulBlockQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiLU-and-mul followed by FP8 dynamic per-group (block) quant.

    The group size (128 or 64) is taken from ``quant_key``. The scale-layout
    (``is_scale_transposed``), UE8M0 (``is_e8m0``) and TMA-alignment
    (``is_tma_aligned``) flags select which variant of the unfused quant op
    call the pattern matches.
    """

    def __init__(
        self,
        quant_key: QuantKey,
        is_scale_transposed: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        super().__init__(quant_key)
        self.is_scale_transposed = is_scale_transposed
        self.is_e8m0 = is_e8m0
        self.is_tma_aligned = is_tma_aligned
        # Number of last-dim elements that share one scale value.
        self.group_size = quant_key.scale.group_shape[1]
        self.quant_matcher = MatcherQuantFP8(
            quant_key,
            has_col_major_scales=is_scale_transposed,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        """Trace inputs: SiLU-and-mul input(s) plus a (1, 1) fp32 scale."""
        placeholder_scale = self.quant_matcher.empty_f32(1, 1)
        return [*self.silu_and_mul_matcher.inputs(), placeholder_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused -> fused per-block rewrite with ``pm_pass``."""
        transposed = self.is_scale_transposed

        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused graph: SiLU-and-mul, then per-group FP8 quantization
            # into a fresh quantized buffer and the provided `scale` buffer.
            activated = self.silu_and_mul_matcher(input)
            q_buffer = torch.empty(
                activated.shape,
                device=activated.device,
                dtype=self.quant_dtype,
            )
            assert scale is not None
            fp8_info = torch.finfo(self.quant_dtype)
            _, q_out, s_out = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=activated,
                output_q=q_buffer,
                output_s=scale,
                group_size=self.group_size,
                eps=1e-10,
                fp8_min=fp8_info.min,
                fp8_max=fp8_info.max,
                scale_ue8m0=self.is_e8m0,
                dummy_is_scale_transposed=transposed,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return q_out, s_out

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The incoming `scale` buffer is not reused; a new fp32 scale
            # buffer is allocated in the layout the fused op is told to use.
            # NOTE(review): is_e8m0 / is_tma_aligned are not forwarded to the
            # fused op and no TMA padding is applied, so every variant gets
            # the same plain fp32 scales — confirm downstream consumers accept
            # that.
            # NOTE(review): scale rows use input.shape[0] (2-D input assumed)
            # while the quantized output keeps all leading dims.
            half = input.shape[-1] // 2
            q_out = torch.empty(
                input.shape[:-1] + (half,),
                device=input.device,
                dtype=self.quant_dtype,
            )
            num_groups = half // self.group_size
            num_rows = input.shape[0]
            if not transposed:
                s_out = torch.empty(
                    (num_rows, num_groups),
                    device=input.device,
                    dtype=torch.float32,
                )
            else:
                # Allocate (groups, rows) and view it as (rows, groups), giving
                # column-major scales.
                s_out = torch.empty(
                    (num_groups, num_rows),
                    device=input.device,
                    dtype=torch.float32,
                ).permute(-1, -2)
            fused = auto_functionalized(
                self.FUSED_OP,
                out=q_out,
                input=input,
                scales=s_out,
                group_size=self.group_size,
                scale_ub=None,
                is_scale_transposed=transposed,
            )
            # presumably the mutated `out` and `scales` buffers
            return fused[1], fused[2]

        register_replacement(
            pattern, replacement, self.get_inputs(), fwd_only, pm_pass
        )


class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    Rewrites SiLU-and-mul + quantization sequences into fused custom ops.

    Patterns are registered with the torch inductor pattern matcher, which
    finds the unfused sequences in the graph and replaces them.

    Patterns can only be registered once, so this pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass"
        )

        SiluMulFp8StaticQuantPattern().register(self.patterns)

        if silu_and_mul_nvfp4_quant_supported:
            SiluMulNvfp4QuantPattern().register(self.patterns)

        if current_platform.is_cuda():
            # One block-quant pattern per (group size, scale layout, UE8M0,
            # TMA alignment) combination.
            block_variants = (
                (quant_key, transposed, e8m0, tma_aligned)
                for quant_key in (kFp8Dynamic128Sym, kFp8Dynamic64Sym)
                for transposed in (False, True)
                for e8m0 in (True, False)
                for tma_aligned in (False, True)
            )
            for quant_key, transposed, e8m0, tma_aligned in block_variants:
                SiluMulBlockQuantPattern(
                    quant_key,
                    is_scale_transposed=transposed,
                    is_e8m0=e8m0,
                    is_tma_aligned=tma_aligned,
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        replaced = self.patterns.apply(graph)
        self.matched_count = replaced
        logger.debug("Replaced %s patterns", replaced)

    def uuid(self) -> str:
        """Return a hash over the source of this pass and its pattern classes."""
        pattern_classes = (
            ActivationQuantPattern,
            SiluMulFp8StaticQuantPattern,
            SiluMulNvfp4QuantPattern,
            SiluMulBlockQuantPattern,
        )
        return VllmInductorPass.hash_source(self, *pattern_classes)