"tests/planner/scaling/disagg_planner.yaml" did not exist on "d23d48ba26764be8d7ad5068964a42441d6a6598"
rms_quant_fusion.py 22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
from typing import Any, NamedTuple

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Dynamic,
    kStaticTensorScale,
)
from vllm.platforms import current_platform

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
    MatcherFusedAddRMSNorm,
    MatcherQuantFP8,
    MatcherRMSNorm,
)

36
logger = init_logger(__name__)

# Dtype of FP8-quantized tensors on the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype used for FP4-quantized tensors.
FP4_DTYPE = torch.uint8
39
40


41
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
42
43
44
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


45
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
46
47
48
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


49
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
50
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
51

52

53
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
54
55
56
    return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")


57
58
# Handles to the unfused RMSNorm custom ops.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Standalone (unfused) quantization op for each supported quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The FP4 quant op may be absent from torch.ops._C, hence the hasattr guard.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out

if current_platform.is_cuda():
    # A single per-token-group op serves both group widths; the group size is
    # supplied as an argument when the op is called.
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
70
71
72
73
74
75
76
77


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.

    Used as the lookup key into ``FUSED_OPS``.

    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    # Quantization scheme applied to the RMSNorm output.
    quant: QuantKey
    # True when the residual add is fused in (fused_add_rms_norm),
    # False for plain rms_norm.
    fused_add: bool

    def __str__(self) -> str:
        # Renders as "FusedQuantKey(<quant>, with residual)" or
        # "FusedQuantKey(<quant>, without residual)".
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"
87
88


89
# Fused RMSNorm + quant op for each supported (quant scheme, residual add)
# pair. The static variants have separate ops with and without the residual
# add. The dynamic per-token and per-block variants share one op for both
# cases; their replacements pass ``residual=None`` when there is no residual.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """Shared setup for the RMSNorm + quantization fusion patterns.

    Each subclass builds a ``FusedRMSQuantKey`` describing its fusion and
    implements ``register``, which adds a pattern/replacement pair to a
    pattern-matcher pass. This base class does three things:

    - resolves the fused op for the key from ``FUSED_OPS``;
    - creates the matcher for the unfused RMSNorm step;
    - creates the matcher for the unfused quantization step.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon. It is given to the RMSNorm matcher and
                passed to the fused op by subclass replacements.
            key: quantization scheme plus whether the residual add is fused.
                It must be a key of ``FUSED_OPS``.
            has_col_major_scales: forwarded to ``MatcherQuantFP8``.
            is_e8m0: forwarded to ``MatcherQuantFP8``.
            is_tma_aligned: forwarded to ``MatcherQuantFP8``.

        Raises:
            AssertionError: if ``key`` has no entry in ``FUSED_OPS``.
        """
        self.epsilon = epsilon
        # Dtype of the quantized output tensor.
        self.quant_dtype = key.quant.dtype

        # Dtype the replacements cast their input to. It is None when the
        # current vLLM config has no model config.
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        # Fused op called by the replacement graphs.
        self.FUSED_OP = FUSED_OPS[key]

        # Matcher for the unfused normalization: RMSNorm with a residual add
        # when the key requests one, plain RMSNorm otherwise.
        self.rmsnorm_matcher = (
            MatcherFusedAddRMSNorm(epsilon)
            if key.fused_add
            else MatcherRMSNorm(epsilon)
        )
        # Matcher for the unfused quantization step.
        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
145

146
147

class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse ``rms_norm`` followed by static tensor-scale quantization.

    Matches RMSNorm followed by a quant op that takes a precomputed ``scale``.
    The match is replaced by one call to the fused op registered for
    (static tensor scale, no residual add).
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the quantized tensor is an output of the matched subgraph.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # Trial call of the pattern on the example inputs; the result is
        # discarded. NOTE(review): presumably an early check that the pattern
        # traces cleanly. Confirm before removing.
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
197
198
199


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse ``fused_add_rms_norm`` followed by static tensor-scale quantization.

    The matched subgraph returns the quantized tensor and the updated
    residual. The replacement produces both with one fused-op call.
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
260
261


262
263
264
265
266
267
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuse ``fused_add_rms_norm`` followed by per-token-group FP8 quant.

    The pattern calls ``self.quant_matcher.QUANT_OP`` directly with explicit
    group-quant arguments. It writes into a caller-provided ``scale`` buffer
    and returns (quantized result, updated residual, scale). The replacement
    produces the same three tensors with one call to the per-block fused op.

    NOTE(review): the replacement forwards ``has_col_major_scales`` to the
    fused op (as ``is_scale_transposed``), but not ``is_e8m0`` or
    ``is_tma_aligned``. Variants registered with ``is_e8m0=True`` therefore
    rely on the fused op producing equivalent scales. Confirm.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # group_shape[1] is the group size passed to the fused op.
        self.group_shape = group_shape
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            quant_key = self.quant_matcher.quant_key
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=quant_key.dtype,
            )
            assert scale is not None
            # Clamp bounds for the quantized dtype.
            finfo = torch.finfo(quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            # Element 0 of the functionalized call is unused. The mutated
            # output_q / output_s buffers follow it.
            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return result, residual, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        # Example ``scale`` tensor used when tracing pattern and replacement.
        scale = self.quant_matcher.empty_f32(1, 1)

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
        )


class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Fuse ``rms_norm`` followed by per-token-group FP8 quantization.

    The pattern calls ``self.quant_matcher.QUANT_OP`` directly with explicit
    group-quant arguments. It writes into a caller-provided ``scale`` buffer
    and returns (quantized result, scale). The replacement produces both with
    one call to the per-block fused op, passing ``residual=None``.

    NOTE(review): the replacement forwards ``has_col_major_scales`` to the
    fused op (as ``is_scale_transposed``), but not ``is_e8m0`` or
    ``is_tma_aligned``. Variants registered with ``is_e8m0=True`` therefore
    rely on the fused op producing equivalent scales. Confirm.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # group_shape[1] is the group size passed to the fused op.
        self.group_shape = group_shape
        # Stored for parity with FusedAddRMSNormGroupQuantPattern.
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned
        super().__init__(
            epsilon,
            key,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            quant_key = self.quant_matcher.quant_key
            result = torch.empty(
                result_rms.shape,
                device=result_rms.device,
                dtype=quant_key.dtype,
            )
            assert scale is not None
            # Clamp bounds for the quantized dtype.
            finfo = torch.finfo(quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            # Element 0 of the functionalized call is unused. The mutated
            # output_q / output_s buffers follow it.
            _, result, scale = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=result_rms,
                output_q=result,
                output_s=scale,
                group_size=quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return result, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        # Example ``scale`` tensor used when tracing pattern and replacement.
        scale = self.quant_matcher.empty_f32(1, 1)

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs() + [scale],
            pm.fwd_only,
            pm_pass,
        )


456
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse ``rms_norm`` followed by dynamic quantization.

    The quant step computes its own scale, so the matched subgraph returns
    (quantized result, scale). With the default ``GroupShape.PER_TOKEN`` this
    is the per-token variant. The replacement calls the fused op with
    ``residual=None``.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)  # type: ignore[no-any-return]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
509
510
511


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse ``fused_add_rms_norm`` followed by dynamic quantization.

    The matched subgraph returns (quantized result, updated residual, scale).
    With the default ``GroupShape.PER_TOKEN`` this is the per-token variant.
    The replacement produces all three tensors with one fused-op call.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Add this pattern/replacement pair to ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    Patterns are registered for epsilon values 1e-5 and 1e-6 with static and
    dynamic per-token FP8 quantization. On CUDA, per-token-group FP8
    quantization (group widths 128 and 64) is also registered. Every
    combination of the scale-layout, e8m0 and TMA-alignment flags is covered.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                # itertools.product yields combinations in the same order as
                # the equivalent nested loops (last axis varies fastest), so
                # registration order is unchanged.
                group_variants = itertools.product(
                    [GroupShape(1, 128), GroupShape(1, 64)],  # group_shape
                    [True, False],  # has_col_major_scales
                    [True, False],  # is_e8m0
                    [False, True],  # is_tma_aligned
                )
                for group_shape, col_major, is_e8m0, tma_aligned in group_variants:
                    # Fuse fused_add_rms_norm + fp8 group quant
                    FusedAddRMSNormGroupQuantPattern(
                        epsilon,
                        FP8_DTYPE,
                        group_shape=group_shape,
                        is_e8m0=is_e8m0,
                        has_col_major_scales=col_major,
                        is_tma_aligned=tma_aligned,
                    ).register(self.patterns)

                    # Fuse rms_norm + fp8 group quant
                    RMSNormGroupQuantPattern(
                        epsilon,
                        FP8_DTYPE,
                        group_shape=group_shape,
                        is_e8m0=is_e8m0,
                        has_col_major_scales=col_major,
                        is_tma_aligned=tma_aligned,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph``.

        The number of replacements is stored in ``self.matched_count``.
        """
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Identifier derived from the source of this pass and its patterns."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )