"vllm/vscode:/vscode.git/clone" did not exist on "98b09ddc2761545f3164d930b143f84737b1ab43"
fusion.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
18
19
20
21
22
23
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
24
from vllm.platforms import current_platform
25

26
from .inductor_pass import enable_fake_mode
27
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
28

29
# Module-level logger for this fusion pass.
logger = init_logger(__name__)
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 data is represented with uint8 tensors here
# (presumably packed storage; confirm against the quant kernels).
FP4_DTYPE = torch.uint8
32
33
34
35
36
37
38
39
40
41


def empty_bf16(*args, **kwargs):
    """Return an uninitialized bfloat16 tensor on the "cuda" device.

    Positional and keyword arguments are forwarded to ``torch.empty``;
    ``dtype`` and ``device`` are fixed here, so passing either raises
    ``TypeError``.
    """
    return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)


def empty_fp32(*args, **kwargs):
    """Return an uninitialized float32 tensor on the "cuda" device.

    Positional and keyword arguments are forwarded to ``torch.empty``;
    ``dtype`` and ``device`` are fixed here, so passing either raises
    ``TypeError``.
    """
    return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs)


42
43
def empty_i32(*args, **kwargs):
    """Return an uninitialized int32 tensor on the "cuda" device.

    Positional and keyword arguments are forwarded to ``torch.empty``;
    ``dtype`` and ``device`` are fixed here, so passing either raises
    ``TypeError``.
    """
    return torch.empty(*args, dtype=torch.int32, device="cuda", **kwargs)
44

45

46
47
# Unfused RMSNorm custom ops that the patterns in this file match against.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Standalone quantization op for each supported quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The NVFP4 quant op is added only when running on CUDA and the compiled
# extension actually exposes `scaled_fp4_quant`.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
56
57
58
59
60
61
62
63


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        # Renders as "... with residual" or "... without residual".
        residual = "with" if self.fused_add else "without"
        return f"FusedQuantKey({self.quant}, {residual} residual)"
73
74


75
# Fused RMSNorm + quant op for each (quantization scheme, residual add) pair.
# Both dynamic per-token keys map to the same op: the patterns in this file
# call it with `residual=None` or with the residual tensor respectively.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """
    Base class for the RMSNorm + quant fusion patterns.

    Stores the epsilon and quant dtype and looks up the standalone quant op
    (QUANT_OPS) and the fused op (FUSED_OPS) that correspond to `key`.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        """
        epsilon: epsilon value passed to the rms_norm ops in the pattern
        key: identifies the quantization scheme and whether the residual
            add is fused; must be present in QUANT_OPS / FUSED_OPS
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # Standalone quant op that appears in the unfused graph.
        assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
        self.QUANT_OP = QUANT_OPS[key.quant]

        # Fused op that replaces the matched rms_norm + quant sequence.
        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for rms_norm followed by a static (per-tensor scale) quant op,
    replaced by the single fused op selected in the base class.
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ):
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on `pm_pass`."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused form: rms_norm writes result_rms, then quant reads it.
            at1 = auto_functionalized(
                RMS_OP,
                result=result_rms,
                input=input,
                weight=weight,
                epsilon=self.epsilon,
            )
            at2 = auto_functionalized(
                self.QUANT_OP, result=result, input=at1[1], scale=scale
            )

            # result
            return at2[1]

        def replacement(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Fused form: result_rms is not needed, but the signature must
            # mirror pattern().
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        # Example inputs handed to register_replacement for tracing.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
164
165
166


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused_add_rms_norm followed by a static (per-tensor scale)
    quant op, replaced by the single fused op selected in the base class.
    """

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ):
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused form: residual add + rms_norm, then a separate quant
            # of the normalized output (at[1]).
            at = auto_functionalized(
                RMS_ADD_OP,
                input=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            at1 = auto_functionalized(
                self.QUANT_OP, result=result, input=at[1], scale=scale
            )

            # result, residual
            return at1[1], at[2]

        def replacement(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        # Example inputs handed to register_replacement for tracing.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
233
234
235


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for rms_norm followed by a dynamic quant op (per-token group
    shape by default), replaced by the fused op selected in the base class.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ):
        # float32 scale descriptor with the requested group shape; the
        # second positional argument is passed as False.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused form: rms_norm writes result_rms, then quant reads it.
            at1 = auto_functionalized(
                RMS_OP,
                result=result_rms,
                input=input,
                weight=weight,
                epsilon=self.epsilon,
            )
            at2 = auto_functionalized(
                self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
            )

            # result, scale
            return at2[1], at2[2]

        def replacement(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # No residual add in this variant, so residual is passed as None.
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        # Example inputs handed to register_replacement for tracing.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
308
309
310


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern for fused_add_rms_norm followed by a dynamic quant op (per-token
    group shape by default), replaced by the fused op selected in the base
    class.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ):
        # float32 scale descriptor with the requested group shape; the
        # second positional argument is passed as False.
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused form: residual add + rms_norm, then a separate quant
            # of the normalized output (at[1]).
            at = auto_functionalized(
                RMS_ADD_OP,
                input=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            at1 = auto_functionalized(
                self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
            )

            # result, residual, scale
            return at1[1], at[2], at1[2]

        def replacement(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            # The indices are reordered so the outputs line up with what
            # pattern() returns.
            return at[1], at[3], at[2]

        # Example inputs handed to register_replacement for tracing.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build the pattern matcher and register every supported pattern."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Each pattern is registered once per epsilon value listed here.
        for epsilon in [1e-5, 1e-6]:
            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return hash_source() over this pass and all of its pattern classes."""
        return self.hash_source(
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )