# fusion.py (18.6 KB) -- header residue from the web blob viewer ("Newer Older")
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig, get_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
    GroupShape,
    QuantKey,
    ScaleDesc,
18
19
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
20
21
22
23
24
25
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
26
from vllm.platforms import current_platform
27

28
from .inductor_pass import enable_fake_mode
29
30
31
32
33
from .matcher_utils import (
    MatcherFusedAddRMSNorm,
    MatcherQuantFP8,
    MatcherRMSNorm,
)
34
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
35

36
# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# dtype used for FP4 outputs in this module.
# NOTE(review): presumably 4-bit values packed into bytes -- confirm.
FP4_DTYPE = torch.uint8
39
40
41
42
43
44
45
46
47
48


def empty_bf16(*args, **kwargs):
    """Return an uninitialized bfloat16 tensor allocated on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either
    of them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.bfloat16, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


def empty_fp32(*args, **kwargs):
    """Return an uninitialized float32 tensor allocated on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either
    of them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.float32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)


49
50
def empty_i32(*args, **kwargs):
    """Return an uninitialized int32 tensor allocated on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either
    of them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.int32, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)
51

52

53
54
55
def empty_i64(*args, **kwargs):
    """Return an uninitialized int64 tensor allocated on the "cuda" device.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    ``dtype`` and ``device`` are fixed by this helper, so supplying either
    of them raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"dtype": torch.int64, "device": "cuda"}
    return torch.empty(*args, **kwargs, **fixed_options)

56

57
58
# Handles to the unfused custom-op overloads in the ``torch.ops._C`` namespace:
# ``rms_norm`` and ``fused_add_rms_norm`` (the variant that, per the
# FusedRMSQuantKey docstring, also performs the residual add).
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
59

60
# Maps a quantization scheme (QuantKey) to the custom-op overload that
# implements it. The static/dynamic FP8 entries in the literal are currently
# disabled (commented out); the remaining entries are added conditionally
# right below, depending on the platform.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    # kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    # kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    # kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# NVFP4 quantization is only registered on CUDA when the op is present in
# the loaded ``_C`` extension.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
# kFp8Dynamic128Sym and kFp8Dynamic64Sym both map to the same
# per-token-group FP8 quant op; they are registered on CUDA only.
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
70
71
72
73
74
75
76
77


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.

    Instances are used as keys of the ``FUSED_OPS`` table.

    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        """Return a short human-readable description of the key."""
        # Produces "... with residual)" or "... without residual)".
        residual = "with" if self.fused_add else "without"
        return f"FusedQuantKey({self.quant}, {residual} residual)"
87
88


89
# Maps a fusion key (quant scheme + whether the residual add is fused) to the
# fused RMSNorm+quant custom-op overload used as the replacement.
# NOTE(review): every entry is currently commented out, so this table is
# empty and the ``assert key in FUSED_OPS`` check in
# RMSNormQuantPattern.__init__ fails for any key -- confirm this is intended
# for the target platform before enabling the fusion pass.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    # FusedRMSQuantKey(
    #     kFp8StaticTensorSym, False
    # ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8StaticTensorSym, True
    # ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8DynamicTokenSym, False
    # ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8DynamicTokenSym, True
    # ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8Dynamic128Sym, False
    # ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8Dynamic128Sym, True
    # ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8Dynamic64Sym, False
    # ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    # FusedRMSQuantKey(
    #     kFp8Dynamic64Sym, True
    # ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """
    Shared state for the RMSNorm + quantization fusion patterns.

    Subclasses build a FusedRMSQuantKey, pass it here, and then use the
    attributes set up by ``__init__`` (``FUSED_OP``, ``rmsnorm_matcher``,
    ``quant_matcher``, ``epsilon``, ``quant_dtype``, ``model_dtype``) inside
    their ``register`` methods.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ):
        """
        epsilon: RMSNorm epsilon forwarded to the matcher and the fused op
        key: identifies the fused op to look up in FUSED_OPS
        has_col_major_scales: forwarded to MatcherQuantFP8
        is_e8m0: forwarded to MatcherQuantFP8

        Raises AssertionError if ``key`` has no entry in FUSED_OPS.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        # The model dtype is taken from the active vLLM config; it stays None
        # when the config carries no model_config.
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # Pick the matcher variant that agrees with the key's fused_add flag.
        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(
            key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )
141
142
143


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for RMSNorm (no residual add) followed by quantization with a
    static scale (``kStaticTensorScale``), replaced by the fused op found in
    FUSED_OPS for that key.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        symmetric: forwarded to the QuantKey
        """
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the quantized tensor (first output) is part of the pattern.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # Run the pattern once on the example inputs before registering.
        # NOTE(review): presumably an eager sanity check -- confirm whether
        # it is still required.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
187
188
189


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused-add RMSNorm followed by quantization with a static
    scale (``kStaticTensorScale``), replaced by the fused op found in
    FUSED_OPS for that key.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        symmetric: forwarded to the QuantKey
        """
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The second output of the quant matcher is not used here.
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
248
249


250
251
252
253
254
255
256
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused-add RMSNorm followed by group quantization (scale
    described by ``group_shape``), replaced by the fused op found in
    FUSED_OPS for that key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        group_shape: quantization group shape; its second entry is passed to
            the fused op as ``group_size``
        symmetric: forwarded to the QuantKey
        has_col_major_scales: forwarded to the quant matcher and passed to
            the fused op as ``is_scale_transposed``
        is_e8m0: forwarded to the quant matcher
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        super().__init__(
            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)
            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale buffer is created by the quant matcher.
            scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """
    Pattern for RMSNorm (no residual add) followed by group quantization
    (scale described by ``group_shape``), replaced by the fused op found in
    FUSED_OPS for that key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        group_shape: quantization group shape; its second entry is passed to
            the fused op as ``group_size``
        symmetric: forwarded to the QuantKey
        has_col_major_scales: forwarded to the quant matcher
        is_e8m0: forwarded to the quant matcher
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        super().__init__(
            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale layout flag is read back from the quant matcher.
            scale = self.quant_matcher.make_scale(
                input, transposed=self.quant_matcher.has_col_major_scales
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.quant_matcher.has_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


372
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for RMSNorm (no residual add) followed by dynamic quantization
    (scale produced by the quant op; ``GroupShape.PER_TOKEN`` by default),
    replaced by the fused op found in FUSED_OPS for that key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        group_shape: group shape used for the scale descriptor
        symmetric: forwarded to the QuantKey
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale buffer is created by the quant matcher.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
421
422
423


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused-add RMSNorm followed by dynamic quantization (scale
    produced by the quant op; ``GroupShape.PER_TOKEN`` by default), replaced
    by the fused op found in FUSED_OPS for that key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ):
        """
        epsilon: RMSNorm epsilon to match
        quant_dtype: dtype of the quantized output
        group_shape: group shape used for the scale descriptor
        symmetric: forwarded to the QuantKey
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale buffer is created by the quant matcher.
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    All patterns are registered once in ``__init__`` on a single
    PatternMatcherPass; ``__call__`` applies them to a graph.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """
        Build the pattern set for every supported epsilon / quant scheme.

        NOTE(review): each pattern constructor asserts that its key is in
        FUSED_OPS, so construction raises AssertionError when the fused op
        for a registered pattern is missing from that table.
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
                    for has_col_major_scales in [True, False]:
                        for is_e8m0 in [True, False]:
                            # Fuse fused_add_rms_norm + fp8 group quant
                            FusedAddRMSNormGroupQuantPattern(
                                epsilon,
                                FP8_DTYPE,
                                group_shape=group_shape,
                                has_col_major_scales=has_col_major_scales,
                                is_e8m0=is_e8m0,
                            ).register(self.patterns)

                            # Fuse rms_norm + fp8 group quant
                            RMSNormGroupQuantPattern(
                                epsilon,
                                FP8_DTYPE,
                                group_shape=group_shape,
                                has_col_major_scales=has_col_major_scales,
                                is_e8m0=is_e8m0,
                            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return an identifier derived from the source of this pass and its patterns."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )