# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any, NamedTuple

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

import vllm.ir.ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Dynamic,
    kStaticTensorScale,
)
from vllm.platforms import current_platform

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
    MatcherFusedAddRMSNorm,
    MatcherQuantFP8,
)

logger = init_logger(__name__)

# Quantized dtypes used when registering the fusion patterns below.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): presumably fp4 values are stored packed in uint8 — confirm.
FP4_DTYPE = torch.uint8

# Op overload that `_rms_input_weight_dtype_match` looks for among matched nodes.
_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default


# TODO: extend rmsnorm quant kernels to support mixed input/weight dtypes,
# and remove this check.
def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
    """Prevent fusion when rms_norm input and weight dtypes differ."""
    for node in match.nodes:
        if node.target == _RMS_NORM_OP:
            # rms_norm(x, weight, epsilon, variance_size)
            x, weight = node.args[0], node.args[1]
            if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
                return x.meta["val"].dtype == weight.meta["val"].dtype
    return True


57
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
58
59
60
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


61
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
62
63
64
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


65
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
66
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
67

68

69
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
70
71
72
    return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")


73
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Standalone (unfused) quantization ops, keyed by the quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The fp4 quant op is only added on CUDA, and only if the op is present.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
# Both group sizes (128 and 64) map to the same per-token-group op; CUDA only.
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        """Render as ``FusedQuantKey(<quant>, with[out] residual)``."""
        return (
            f"FusedQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )


# Fused rms_norm + quant ops, keyed by (quantization scheme, fused_add).
# The static scheme has a separate op for the residual-add variant. The
# dynamic per-token and per-block schemes map both variants to a single op;
# the replacements below pass `residual=None` or a tensor accordingly.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """Shared state for the rms_norm + quant fusion patterns.

    Stores epsilon, the quantized dtype, the model dtype from the current
    vLLM config, the fused op selected from ``FUSED_OPS`` and the matcher
    objects that subclasses use inside their ``register`` methods.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        """Look up the fused op for ``key`` and build the matchers.

        Args:
            epsilon: epsilon value of the rms_norm to match.
            key: identifies the quant scheme and whether the residual add
                is fused; must be present in ``FUSED_OPS``.
            has_col_major_scales: forwarded to ``MatcherQuantFP8``.
            is_e8m0: forwarded to ``MatcherQuantFP8``.
            is_tma_aligned: forwarded to ``MatcherQuantFP8``.

        Raises:
            AssertionError: if ``key`` has no entry in ``FUSED_OPS``.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        # model_dtype is None when the config carries no model_config.
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # Only the fused-add patterns use an rmsnorm matcher; the others call
        # vllm.ir.ops.rms_norm directly in their pattern functions.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuses rms_norm followed by static per-tensor quantization."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        """Build the non-fused-add key with a static tensor scale."""
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            empty_bf16(5, 16),  # input
            empty_bf16(16),  # weight
            self.quant_matcher.inputs()[1],  # scale
        ]
        # NOTE(review): the pattern is invoked once eagerly here and its
        # result is discarded — presumably a sanity check; confirm before
        # removing.
        pattern(*inputs)

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuses fused_add_rms_norm followed by static per-tensor quantization."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        """Build the fused-add key with a static tensor scale."""
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuses fused_add_rms_norm followed by fp8 group quantization."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        """Build the fused-add key for ``group_shape`` and keep the flags.

        The scale-layout flags are stored on the instance because the
        pattern and replacement functions read them directly.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            # The unfused group quant call is written out explicitly here,
            # writing into the caller-provided `scale` tensor.
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=self.quant_matcher.quant_key.dtype,
            )
            assert scale is not None
            finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=self.quant_matcher.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return result, residual, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)

            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        scale = self.quant_matcher.empty_f32(1, 1)

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuses rms_norm followed by fp8 group quantization."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        """Build the non-fused-add key for ``group_shape`` and keep the flags.

        The scale-layout flags are stored on the instance because the
        pattern and replacement functions read them directly.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=self.has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
            # The unfused group quant call is written out explicitly here,
            # writing into the caller-provided `scale` tensor.
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=self.quant_matcher.quant_key.dtype,
            )
            assert scale is not None
            finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=self.quant_matcher.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return result, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            [
                empty_bf16(5, 16),  # input
                empty_bf16(16),  # weight
                self.quant_matcher.empty_f32(1, 1),  # scale
            ],
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses rms_norm followed by dynamic quantization (per-token default)."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the non-fused-add key with a dynamic float32 scale."""
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
            # result, scale
            return self.quant_matcher(result_rms)  # type: ignore[no-any-return]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            [
                empty_bf16(5, 16),  # input
                empty_bf16(16),  # weight
            ],
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses fused_add_rms_norm followed by dynamic quantization."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        """Build the fused-add key with a dynamic float32 scale."""
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register every fusion pattern for the supported epsilon values."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                # product() varies the rightmost list fastest, so patterns
                # are registered in the same order as the equivalent nested
                # loops (group_shape outermost, is_tma_aligned innermost).
                group_quant_variants = itertools.product(
                    [GroupShape(1, 128), GroupShape(1, 64)],  # group_shape
                    [True, False],  # has_col_major_scales
                    [True, False],  # is_e8m0
                    [False, True],  # is_tma_aligned
                )
                for (
                    group_shape,
                    has_col_major_scales,
                    is_e8m0,
                    is_tma_aligned,
                ) in group_quant_variants:
                    # Fuse fused_add_rms_norm + fp8 group quant
                    FusedAddRMSNormGroupQuantPattern(
                        epsilon,
                        FP8_DTYPE,
                        group_shape=group_shape,
                        is_e8m0=is_e8m0,
                        has_col_major_scales=has_col_major_scales,
                        is_tma_aligned=is_tma_aligned,
                    ).register(self.patterns)

                    # Fuse rms_norm + fp8 group quant
                    RMSNormGroupQuantPattern(
                        epsilon,
                        FP8_DTYPE,
                        group_shape=group_shape,
                        is_e8m0=is_e8m0,
                        has_col_major_scales=has_col_major_scales,
                        is_tma_aligned=is_tma_aligned,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Return a hash over the source of this pass and its pattern classes."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )