# fusion.py (19.5 KB)
1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, NamedTuple

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
from .matcher_utils import (
    MatcherFusedAddRMSNorm,
    MatcherQuantFP8,
    MatcherRMSNorm,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
35

36
# Module-level logger, used by the fusion pass to report match counts.
logger = init_logger(__name__)
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 values are represented with a uint8 tensor dtype.
FP4_DTYPE = torch.uint8
39
40


41
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
42
43
44
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


45
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
46
47
48
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


49
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
50
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
51

52

53
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
54
55
56
    return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")


57
58
# Handles to the standalone RMSNorm custom ops registered under torch.ops._C.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Maps a quantization scheme to the standalone custom op that implements it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The FP4 quant op is added only on CUDA and only if torch.ops._C exposes it.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
# On CUDA, both group-quant keys (128 and 64) map to the same
# per_token_group_fp8_quant op.
if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
70
71
72
73
74
75
76
77


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add

    Instances are used as the keys of the FUSED_OPS lookup table.
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        # Renders as "... with residual)" when fused_add is set and
        # "... without residual)" otherwise.
        return (
            f"FusedQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )
87
88


89
# Maps (quant scheme, with/without residual add) to the fused RMSNorm + quant
# custom op. Static per-tensor quant has a distinct op for each residual
# variant; the dynamic per-token and per-block entries map both residual
# variants to one op (the patterns below pass `residual` as an argument that
# is None when there is no residual add).
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic128Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, False
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8Dynamic64Sym, True
    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """
    Shared state for the RMSNorm + quant fusion patterns.

    Stores the epsilon, the quantized dtype, the model dtype from the current
    vLLM config, the fused op selected from FUSED_OPS, and the matcher objects
    the subclasses use to build their pattern and replacement functions.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ) -> None:
        """
        :param epsilon: RMSNorm epsilon, passed to the matcher and fused op.
        :param key: identifies the fusion; must be a key of FUSED_OPS
            (checked with an assert).
        :param has_col_major_scales: forwarded to MatcherQuantFP8.
        :param is_e8m0: forwarded to MatcherQuantFP8.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        # model_dtype is None when the current config has no model_config.
        self.model_dtype = config.model_config.dtype if config.model_config else None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # The residual-add variant uses the fused-add matcher.
        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(
            key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )
141

142
143

class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Fuses rms_norm (no residual add) followed by quantization with a
    static per-tensor scale into the fused op selected from FUSED_OPS.
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            result_rms = self.rmsnorm_matcher(input, weight)
            # Only the first element returned by the quant matcher is used.
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
193
194
195


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Fuses fused_add_rms_norm (RMSNorm with residual add) followed by
    quantization with a static per-tensor scale into the fused op selected
    from FUSED_OPS.
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
256
257


258
259
260
261
262
263
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """
    Fuses fused_add_rms_norm (RMSNorm with residual add) followed by group
    quantization into the fused op selected from FUSED_OPS. The scale is
    produced by the op rather than passed in as a pattern input.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Kept on the instance because the replacement reads them.
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        super().__init__(
            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)
            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                # Second entry of the group shape is passed as the group size.
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """
    Fuses rms_norm (no residual add) followed by group quantization into
    the fused op selected from FUSED_OPS. The scale is produced by the op
    rather than passed in as a pattern input.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Kept on the instance because the replacement reads it.
        self.group_shape = group_shape
        super().__init__(
            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            # The scale layout flag is read back from the quant matcher.
            scale = self.quant_matcher.make_scale(
                input, transposed=self.quant_matcher.has_col_major_scales
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                # Second entry of the group shape is passed as the group size.
                group_size=self.group_shape[1],
                is_scale_transposed=self.quant_matcher.has_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


386
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Fuses rms_norm (no residual add) followed by dynamic quantization
    (per-token group shape by default) into the fused op selected from
    FUSED_OPS.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)  # type: ignore[no-any-return]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
439
440
441


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Fuses fused_add_rms_norm (RMSNorm with residual add) followed by
    dynamic quantization (per-token group shape by default) into the fused
    op selected from FUSED_OPS.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Build the pattern matcher pass and register every fusion pattern.

        Patterns are registered for epsilon values 1e-5 and 1e-6.
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
                    for has_col_major_scales in [True, False]:
                        for is_e8m0 in [True, False]:
                            # Fuse fused_add_rms_norm + fp8 group quant
                            FusedAddRMSNormGroupQuantPattern(
                                epsilon,
                                FP8_DTYPE,
                                group_shape=group_shape,
                                has_col_major_scales=has_col_major_scales,
                                is_e8m0=is_e8m0,
                            ).register(self.patterns)

                            # Fuse rms_norm + fp8 group quant
                            RMSNormGroupQuantPattern(
                                epsilon,
                                FP8_DTYPE,
                                group_shape=group_shape,
                                has_col_major_scales=has_col_major_scales,
                                is_e8m0=is_e8m0,
                            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Return ``hash_source`` computed over this pass and every pattern class."""
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )