fusion.py 18.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig, get_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
    GroupShape,
    QuantKey,
    ScaleDesc,
18
19
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
20
21
22
23
24
25
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
26
27
28
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_block_fp8_supported,
)
29
from vllm.platforms import current_platform
30
31
32
33
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    should_use_deepgemm_for_fp8_linear_for_nk,
)
34

35
from .inductor_pass import enable_fake_mode
36
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
37
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
38

39
# Module-level logger, named after this module.
logger = init_logger(__name__)
40
# FP8 dtype as reported by the current platform; used as the quantized output
# dtype for every pattern registered by RMSNormQuantFusionPass.
FP8_DTYPE = current_platform.fp8_dtype()
41
# dtype used for FP4 data.
# NOTE(review): presumably uint8 because FP4 values are stored packed in
# bytes — confirm against the fp4 quant op.
FP4_DTYPE = torch.uint8
42
43
44
45
46
47
48
49
50
51


def empty_bf16(*args, **kwargs):
    """Return an uninitialized bfloat16 tensor on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either of
    them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.bfloat16, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


def empty_fp32(*args, **kwargs):
    """Return an uninitialized float32 tensor on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either of
    them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.float32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


52
53
def empty_i32(*args, **kwargs):
    """Return an uninitialized int32 tensor on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either of
    them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.int32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)
54

55

56
57
58
59
def empty_i64(*args, **kwargs):
    """Return an uninitialized int64 tensor on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either of
    them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.int64, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


60
61
# Handles to the unfused RMSNorm custom ops registered under torch.ops._C:
# the plain variant and the fused-add (residual) variant.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
62

63
# Maps each supported quantization scheme to the unfused custom op that
# performs it. The three FP8 entries below are registered unconditionally.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The NVFP4 op is only added on CUDA, and only if the extension exposes it.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
# Group FP8 quantization (group sizes 128 and 64 share one op) is CUDA-only.
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
73
74
75
76
77
78
79
80


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self):
        """Return a short label: ``FusedQuantKey(<quant>, with[out] residual)``."""
        return (
            f"FusedQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )
90
91


92
# Maps (quant scheme, fused_add) to the fused RMSNorm + quant custom op.
# For static per-tensor quant the two fused_add variants are separate ops;
# for dynamic per-token and per-block quant both variants map to the same op
# (the patterns below pass `residual` as an argument to that op).
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """Shared setup for the RMSNorm + quant fusion patterns in this module.

    The initializer resolves the fused op for ``key`` from ``FUSED_OPS`` and
    builds the matcher objects that subclasses use in ``register``.

    Attributes set here:
        epsilon: RMSNorm epsilon the pattern is specialized for.
        quant_dtype: dtype of the quantized output (``key.quant.dtype``).
        model_dtype: dtype from the current model config, or ``None`` when no
            model config is set.
        FUSED_OP: fused op looked up in ``FUSED_OPS``.
        rmsnorm_matcher: ``MatcherFusedAddRMSNorm`` when ``key.fused_add`` is
            set, otherwise ``MatcherRMSNorm``.
        quant_matcher: ``MatcherQuantFP8`` for ``key.quant``.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        model_config = config.model_config
        self.model_dtype = model_config.dtype if model_config else None

        # groupwise FP8 linear uses col major scales if deepgemm and cutlass
        if model_config:
            using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
                self.model_dtype,
                model_config.hf_config.intermediate_size,
                model_config.hf_config.hidden_size,
            )
        else:
            # Without a model config there are no layer sizes to query, so
            # treat DeepGEMM as unused instead of dereferencing None.
            using_deepgemm = False
        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(
            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
        )
147

148
149

class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Pattern: RMSNorm followed by static per-tensor quantization.

    The matched subgraph is replaced by a single call to ``self.FUSED_OP``
    that writes the quantized result into a freshly allocated tensor.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # The quant matcher returns a pair; only the first element
            # (the quantized tensor) is part of this pattern's output.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the return value is discarded and the sibling patterns
        # in this file do not make this call — looks like a leftover eager
        # smoke-test of the pattern; confirm before removing.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
193
194
195


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Pattern: fused-add RMSNorm followed by static per-tensor quantization.

    The matched subgraph is replaced by a single call to ``self.FUSED_OP``,
    which yields both the quantized result and the updated residual.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The second element returned by the quant matcher is not part of
            # this pattern's output.
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
254
255


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Pattern: fused-add RMSNorm followed by dynamic group quantization.

    The matched subgraph is replaced by one call to ``self.FUSED_OP`` that
    yields the quantized result, the updated residual and the scales.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        # Kept on the instance because register() reads the group size from it.
        self.group_shape = group_shape
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=True))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            normed, residual = self.rmsnorm_matcher(input, weight, residual)
            quantized, scales = self.quant_matcher(normed)
            return quantized, residual, scales

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            col_major = self.quant_matcher.use_col_major_scales
            out = torch.empty_like(input, dtype=self.quant_dtype)
            out_scale = self.quant_matcher.make_scale(input, transposed=col_major)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                weight=weight,
                scale=out_scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=col_major,
            )

            # result, residual, scale
            return fused[1], fused[3], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Pattern: RMSNorm (no residual add) followed by dynamic group quantization.

    The matched subgraph is replaced by one call to ``self.FUSED_OP`` that
    yields the quantized result and the scales.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        # Kept on the instance because register() reads the group size from it.
        self.group_shape = group_shape
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=False))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scales = self.quant_matcher(normed)
            return quantized, scales

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            col_major = self.quant_matcher.use_col_major_scales
            out = torch.empty_like(input, dtype=self.quant_dtype)
            out_scale = self.quant_matcher.make_scale(input, transposed=col_major)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                weight=weight,
                scale=out_scale,
                epsilon=self.epsilon,
                scale_ub=None,
                # No residual tensor is supplied in the non-fused-add variant.
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=col_major,
            )

            # result, scale
            return fused[1], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


370
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Pattern: RMSNorm (no residual add) followed by dynamic quantization.

    ``group_shape`` defaults to ``GroupShape.PER_TOKEN``. The matched subgraph
    is replaced by one call to ``self.FUSED_OP`` that yields the quantized
    result and the scales.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                # No residual tensor is supplied in the non-fused-add variant.
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
419
420
421


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Pattern: fused-add RMSNorm followed by dynamic quantization.

    ``group_shape`` defaults to ``GroupShape.PER_TOKEN``. The matched subgraph
    is replaced by one call to ``self.FUSED_OP`` that yields the quantized
    result, the updated residual and the scales.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            # (the op's outputs are indexed result=1, scale=2, residual=3;
            # reorder to match what pattern() returns)
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    Patterns are registered for epsilon values 1e-5 and 1e-6 and for the
    platform FP8 dtype. Group-quant patterns (group sizes 128 and 64) are
    registered on CUDA only.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                for group_size in (128, 64):
                    group_shape = GroupShape(1, group_size)

                    # Fuse fused_add_rms_norm + fp8 group quant
                    FusedAddRMSNormGroupQuantPattern(
                        epsilon, FP8_DTYPE, group_shape=group_shape
                    ).register(self.patterns)

                    # Fuse rms_norm + fp8 group quant
                    RMSNormGroupQuantPattern(
                        epsilon, FP8_DTYPE, group_shape=group_shape
                    ).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return an identifier derived from the source of this pass and of
        every pattern class it registers."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )