# fusion.py (12.2 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig, get_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
18
19
20
21
22
23
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
24
from vllm.platforms import current_platform
25

26
from .inductor_pass import enable_fake_mode
27
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
28
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
29

30
logger = init_logger(__name__)
31
FP8_DTYPE = current_platform.fp8_dtype()
32
FP4_DTYPE = torch.uint8
33
34
35
36
37
38
39
40
41
42


def empty_bf16(*args, **kwargs):
    """Return an uninitialized bfloat16 tensor on the ``"cuda"`` device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype and device are fixed here, so supplying either one as a
    keyword raises ``TypeError``.
    """
    fixed_options = {"dtype": torch.bfloat16, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


def empty_fp32(*args, **kwargs):
    """Return an uninitialized float32 tensor on the ``"cuda"`` device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype and device are fixed here, so supplying either one as a
    keyword raises ``TypeError``.
    """
    fixed_options = {"dtype": torch.float32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


43
44
def empty_i32(*args, **kwargs):
    """Return an uninitialized int32 tensor on the ``"cuda"`` device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype and device are fixed here, so supplying either one as a
    keyword raises ``TypeError``.
    """
    fixed_options = {"dtype": torch.int32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)
45

46

47
48
49
50
def empty_i64(*args, **kwargs):
    """Return an uninitialized int64 tensor on the ``"cuda"`` device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype and device are fixed here, so supplying either one as a
    keyword raises ``TypeError``.
    """
    fixed_options = {"dtype": torch.int64, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


51
52
# Handles to the default overloads of the `rms_norm` and `fused_add_rms_norm`
# custom ops exposed under `torch.ops._C`.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
53

54
# Maps each supported quantization scheme (QuantKey) to the `torch.ops._C`
# op overload that performs that quantization on its own.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The NVFP4 quant op is registered only on CUDA platforms, and only when
# `torch.ops._C` actually exposes `scaled_fp4_quant`.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
61
62
63
64
65
66
67
68


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        """Render the key as ``FusedQuantKey(<quant>, with[out] residual)``."""
        # The two adjacent f-strings concatenate into "with residual" when
        # fused_add is true and "without residual" otherwise.
        return (
            f"FusedQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )
78
79


80
# Maps each supported (quantization scheme, fused residual add) combination
# to the `torch.ops._C` op overload used as the fused replacement.
# Both dynamic per-token entries point at the same op,
# `rms_norm_dynamic_per_token_quant`; the patterns that use it pass
# `residual=None` or a residual tensor to that single op.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """
    Base class holding the state shared by the RMSNorm + quant fusion patterns:
    the epsilon, the quantized output dtype, the model dtype from the current
    vLLM config, the fused op looked up in FUSED_OPS, and the matcher helpers
    for the RMSNorm and quant halves of the pattern.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        """
        epsilon: RMSNorm epsilon forwarded to the matcher and to the fused op
        key: identifies the quant scheme and whether the residual add is fused

        Raises AssertionError if there is no entry for `key` in FUSED_OPS.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        # model_dtype is None when the current config has no model_config.
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # Pick the matcher that corresponds to the presence of the residual add.
        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(key.quant)

113
114

class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for rms_norm (no residual add) followed by a quant with a static
    tensor scale, replaced by the fused op selected in the base class.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        """
        epsilon: RMSNorm epsilon
        quant_dtype: dtype of the quantized output
        symmetric: forwarded to QuantKey
        """
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its fused replacement on `pm_pass`."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the first output of the quant matcher is part of the pattern.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffer passed to the fused op as `result`.
            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the pattern is invoked eagerly here and its result is
        # discarded; the sibling pattern classes do not do this. Presumably a
        # sanity check of the example inputs - confirm whether it is needed.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
158
159
160


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused_add_rms_norm (with residual) followed by a quant with a
    static tensor scale, replaced by the fused op selected in the base class.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        """
        epsilon: RMSNorm epsilon
        quant_dtype: dtype of the quantized output
        symmetric: forwarded to QuantKey
        """
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The second output of the quant matcher is not part of the pattern.
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffer passed to the fused op as `result`.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
219
220
221


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for rms_norm (no residual add) followed by a dynamic quant whose
    scale is produced by the quant itself, replaced by the fused op selected
    in the base class.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        """
        epsilon: RMSNorm epsilon
        quant_dtype: dtype of the quantized output
        group_shape: scale grouping forwarded to ScaleDesc
        symmetric: forwarded to QuantKey
        """
        # NOTE(review): the second positional argument (False) is presumably
        # ScaleDesc's "static" flag - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its fused replacement on `pm_pass`."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffers passed to the fused op as `result` and `scale`.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
270
271
272


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused_add_rms_norm (with residual) followed by a dynamic quant
    whose scale is produced by the quant itself, replaced by the fused op
    selected in the base class.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        """
        epsilon: RMSNorm epsilon
        quant_dtype: dtype of the quantized output
        group_shape: scale grouping forwarded to ScaleDesc
        symmetric: forwarded to QuantKey
        """
        # NOTE(review): the second positional argument (False) is presumably
        # ScaleDesc's "static" flag - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern and its fused replacement on `pm_pass`."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffers passed to the fused op as `result` and `scale`.
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            # (indices are reordered so the outputs line up with the
            # pattern's return order)
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    # NOTE(review): presumably enable_fake_mode makes the example inputs built
    # during pattern registration fake tensors - confirm in inductor_pass.
    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build the pattern matcher pass and register every fusion pattern."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        # Each pattern is registered once per epsilon value listed here.
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return `hash_source` computed over this pass and the pattern classes."""
        # NOTE(review): presumably this makes the pass identity change when
        # any of the listed classes is edited - confirm against hash_source.
        return self.hash_source(
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )