# fusion.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
    kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
17
from vllm.platforms import current_platform
18

19
from .inductor_pass import enable_fake_mode
20
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
21

22
# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)
# fp8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# dtype used for fp4 data.
# NOTE(review): presumably packed fp4 values stored as bytes — confirm.
FP4_DTYPE = torch.uint8
25
26
27
28
29
30
31
32
33
34


def empty_bf16(*args, **kwargs):
    """Return an uninitialized bfloat16 tensor created with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The tensor is placed on ``"cuda"`` unless the caller passes an explicit
    ``device`` keyword. The dtype is always ``torch.bfloat16``, so passing
    ``dtype`` raises ``TypeError``.
    """
    # "cuda" is only a default: an explicit device= from the caller wins
    # instead of colliding with a hard-coded keyword.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.bfloat16, **kwargs)


def empty_fp32(*args, **kwargs):
    """Return an uninitialized float32 tensor created with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The tensor is placed on ``"cuda"`` unless the caller passes an explicit
    ``device`` keyword. The dtype is always ``torch.float32``, so passing
    ``dtype`` raises ``TypeError``.
    """
    # "cuda" is only a default: an explicit device= from the caller wins
    # instead of colliding with a hard-coded keyword.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.float32, **kwargs)


35
36
def empty_i32(*args, **kwargs):
    """Return an uninitialized int32 tensor created with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The tensor is placed on ``"cuda"`` unless the caller passes an explicit
    ``device`` keyword. The dtype is always ``torch.int32``, so passing
    ``dtype`` raises ``TypeError``.
    """
    # "cuda" is only a default: an explicit device= from the caller wins
    # instead of colliding with a hard-coded keyword.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, dtype=torch.int32, **kwargs)
37

38

39
40
# Unfused custom ops that the patterns in this file match on.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Maps a quantization scheme to the custom op that implements it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym:
    torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym:
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym:
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
# The fp4 quant op is only registered on CUDA platforms, and only when the
# op is actually present in torch.ops._C.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


class FusedRMSQuantKey(NamedTuple):
    """
    Key describing one flavor of RMSNorm + quant fusion.
    quant: the quantization scheme that is fused with the norm
    fused_add: whether the op also performs the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # "with residual" when the add is fused, "without residual" otherwise.
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"


68
# Maps a (quant scheme, fused_add) key to the fused rms_norm + quant custom op.
# Both dynamic per-token entries point at the same op.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}


class RMSNormQuantPattern:
    """Shared state for the rms_norm + quant fusion patterns.

    Stores the epsilon and quantized dtype, and looks up the unfused quant op
    (``QUANT_OP``) and the fused op (``FUSED_OP``) for the given key.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        quant_key = key.quant
        self.quant_dtype = quant_key.dtype

        # The quantization scheme must have a registered unfused op.
        assert quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {quant_key}")
        self.QUANT_OP = QUANT_OPS[quant_key]

        # The full key (scheme + fused_add) must have a registered fused op.
        assert key in FUSED_OPS, (
            f"unsupported fused rmsnorm+quant op for {key}")
        self.FUSED_OP = FUSED_OPS[key]


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Replace ``RMS_OP`` followed by a static-scale quant op with one fused op.

    The quant key is built from ``quant_dtype``, ``kStaticTensorScale`` and
    ``symmetric``, with ``fused_add=False``.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric: bool = True):
        fused_key = FusedRMSQuantKey(fused_add=False,
                                     quant=QuantKey(dtype=quant_dtype,
                                                    scale=kStaticTensorScale,
                                                    symmetric=symmetric))
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Cannot use methods, as the self argument affects tracing
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            # The quant op consumes the rms_norm output (at1[1]).
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale)

            # result
            return at2[1]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # result_rms is accepted to mirror the pattern's signature but is
            # not passed to the fused op.
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result
            return at[1]

        # Example inputs handed to register_replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
                                pm_pass)


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Replace ``RMS_ADD_OP`` followed by a static-scale quant op with one
    fused op.

    The quant key is built from ``quant_dtype``, ``kStaticTensorScale`` and
    ``symmetric``, with ``fused_add=True``.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric: bool = True):
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=kStaticTensorScale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            # The quant op consumes the first output of the add+norm op.
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale)

            # result, residual
            return at1[1], at[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon)

            # result, residual
            return at[1], at[2]

        # Example inputs handed to register_replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
209
210
211
212
213
214
215


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Replace ``RMS_OP`` followed by a dynamic quant op with one fused op.

    The quant key uses a ``ScaleDesc(torch.float32, False, group_shape)``
    scale and ``fused_add=False``. ``group_shape`` defaults to
    ``GroupShape.PER_TOKEN``.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric: bool = True):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(fused_add=False,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=scale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(RMS_OP,
                                      result=result_rms,
                                      input=input,
                                      weight=weight,
                                      epsilon=self.epsilon)
            # The quant op consumes the rms_norm output (at1[1]).
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, scale
            return at2[1], at2[2]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # result_rms is accepted to mirror the pattern's signature but is
            # not passed to the fused op; no residual is used in this variant.
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=None)

            # result, scale
            return at[1], at[2]

        # Example inputs handed to register_replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
274
275
276
277
278
279
280


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Replace ``RMS_ADD_OP`` followed by a dynamic quant op with one fused op.

    The quant key uses a ``ScaleDesc(torch.float32, False, group_shape)``
    scale and ``fused_add=True``. ``group_shape`` defaults to
    ``GroupShape.PER_TOKEN``.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric: bool = True):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              scale=scale,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            at = auto_functionalized(RMS_ADD_OP,
                                     input=input,
                                     residual=residual,
                                     weight=weight,
                                     epsilon=self.epsilon)
            # The quant op consumes the first output of the add+norm op.
            at1 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at[1],
                                      scale=scale,
                                      scale_ub=None)

            # result, residual, scale
            return at1[1], at[2], at1[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     weight=weight,
                                     scale=scale,
                                     epsilon=self.epsilon,
                                     scale_ub=None,
                                     residual=residual)

            # result, residual, scale
            # The fused op's outputs are indexed so that the returned tuple
            # has the same order as the pattern's.
            return at[1], at[3], at[2]

        # Example inputs handed to register_replacement.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1)  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build the pattern matcher pass and register all fusion patterns."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass")

        # Each pattern is registered once per epsilon value.
        # NOTE(review): presumably because epsilon is baked into the traced
        # pattern as a constant — confirm.
        for epsilon in [1e-5, 1e-6]:
            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon,
                                      FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns)

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon,
                                       FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Return ``hash_source`` over this pass and the pattern classes."""
        return self.hash_source(self, RMSNormQuantPattern,
                                RMSNormStaticQuantPattern,
                                RMSNormDynamicQuantPattern,
                                FusedAddRMSNormStaticQuantPattern,
                                FusedAddRMSNormDynamicQuantPattern)