fusion.py 14.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
    GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
    kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
17
from vllm.platforms import current_platform
18

19
from .inductor_pass import enable_fake_mode
20
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
21

22
logger = init_logger(__name__)

# Element dtype of the quantized outputs produced by the fp8 patterns below;
# the concrete torch dtype is queried from the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): fp4 data is presumably stored packed in uint8 bytes — confirm.
FP4_DTYPE = torch.uint8
25
26
27
28
29
30
31
32
33
34


def _empty_on_device(dtype: torch.dtype, /, *args, **kwargs) -> torch.Tensor:
    """Return an uninitialized ``torch.empty`` tensor of ``dtype``.

    ``args``/``kwargs`` are forwarded to ``torch.empty``. The device
    defaults to ``"cuda"`` (the previous hard-coded value) but may now be
    overridden with ``device=`` (e.g. ``"cpu"`` or ``"meta"``). ``dtype`` is
    fixed by the caller, so an explicit ``dtype=`` in ``kwargs`` still raises
    ``TypeError``, exactly as before.
    """
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=dtype)


def empty_bf16(*args, **kwargs):
    """Uninitialized bfloat16 tensor (device defaults to ``"cuda"``)."""
    return _empty_on_device(torch.bfloat16, *args, **kwargs)


def empty_fp32(*args, **kwargs):
    """Uninitialized float32 tensor (device defaults to ``"cuda"``)."""
    return _empty_on_device(torch.float32, *args, **kwargs)


def empty_i32(*args, **kwargs):
    """Uninitialized int32 tensor (device defaults to ``"cuda"``)."""
    return _empty_on_device(torch.int32, *args, **kwargs)
37

38
39


40
41
# Unfused rms_norm custom ops that the fusion patterns match against.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Name in ``torch.ops._C`` of the standalone quant op for each scheme.
_QUANT_OP_NAMES: dict[QuantKey, str] = {
    kFp8StaticTensorSym: "static_scaled_fp8_quant",
    kFp8DynamicTensorSym: "dynamic_scaled_fp8_quant",
    kFp8DynamicTokenSym: "dynamic_per_token_scaled_fp8_quant",
}

# Standalone quant ops available for fusion, keyed by quantization scheme.
# An entry is only added when the op is actually registered in
# ``torch.ops._C`` (the same ``hasattr`` guard used for nvfp4 below). A build
# without one of these kernels therefore still imports this module, and
# RMSNormQuantPattern reports that scheme as unsupported instead of the
# import failing with AttributeError.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    quant_key: getattr(torch.ops._C, op_name).default
    for quant_key, op_name in _QUANT_OP_NAMES.items()
    if hasattr(torch.ops._C, op_name)
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68


class FusedRMSQuantKey(NamedTuple):
    """
    Key selecting one flavour of RMSNorm + quant fusion.
    quant: quantization scheme applied to the normalized output
    fused_add: whether the rms_norm op also performs the residual add
    """
    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # "with residual" when the add is fused, "without residual" otherwise
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"


69
# Name in ``torch.ops._C`` of the fused rms_norm + quant op for each fusion.
# Both dynamic per-token variants use the same op; it takes an optional
# ``residual`` argument (see the dynamic patterns' replacements).
_FUSED_OP_NAMES: dict[FusedRMSQuantKey, str] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    "rms_norm_static_fp8_quant",
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    "fused_add_rms_norm_static_fp8_quant",
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    "rms_norm_dynamic_per_token_quant",
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    "rms_norm_dynamic_per_token_quant",
}

# Fused ops available on this build, keyed by fusion flavour. As with
# QUANT_OPS, an entry is only added when the op is registered in
# ``torch.ops._C``, so missing kernels surface as "unsupported" fusions in
# RMSNormQuantPattern rather than as an import-time AttributeError.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    fused_key: getattr(torch.ops._C, op_name).default
    for fused_key, op_name in _FUSED_OP_NAMES.items()
    if hasattr(torch.ops._C, op_name)
}


class RMSNormQuantPattern:
    """Common base for rms_norm + quant fusion patterns.

    From a FusedRMSQuantKey it resolves the standalone quant op
    (``QUANT_OP``, looked up in QUANT_OPS) and the fused rms_norm+quant op
    (``FUSED_OP``, looked up in FUSED_OPS), asserting that both exist.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        quant_key = key.quant

        # rms_norm epsilon baked into both the pattern and the replacement
        self.epsilon = epsilon
        # dtype of the quantized output tensor
        self.quant_dtype = quant_key.dtype

        assert quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {key.quant}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert key in FUSED_OPS, \
            f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuses ``rms_norm`` followed by static-scale quant into one op.

    Matches an ``RMS_OP`` whose output feeds ``QUANT_OP`` and replaces the
    pair with a single ``FUSED_OP`` call (no residual add).
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        static_quant = QuantKey(dtype=quant_dtype,
                                scale=kStaticTensorScale,
                                symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(fused_add=False, quant=static_quant))

    def register(self, pm_pass: PatternMatcherPass):
        """Add the rms_norm -> quant rewrite to ``pm_pass``."""

        # Plain closures, not methods: a ``self`` parameter would be picked
        # up by the tracer.
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            normed = auto_functionalized(RMS_OP,
                                         result=result_rms,
                                         input=input,
                                         weight=weight,
                                         epsilon=self.epsilon)
            quantized = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=normed[1],
                                            scale=scale)
            # quantized result
            return quantized[1]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon)
            # quantized result
            return fused[1]

        # Example tensors for tracing; order matches the parameters above.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuses ``fused_add_rms_norm`` followed by static-scale quant.

    Matches ``RMS_ADD_OP`` whose normalized output feeds ``QUANT_OP`` and
    replaces the pair with one ``FUSED_OP`` call that also takes and
    returns the residual.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 symmetric=True):
        static_quant = QuantKey(dtype=quant_dtype,
                                scale=kStaticTensorScale,
                                symmetric=symmetric)
        super().__init__(epsilon,
                         FusedRMSQuantKey(fused_add=True, quant=static_quant))

    def register(self, pm_pass: PatternMatcherPass):
        """Add the fused_add_rms_norm -> quant rewrite to ``pm_pass``."""

        # Plain closures, not methods: a ``self`` parameter would be picked
        # up by the tracer.
        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm_add = auto_functionalized(RMS_ADD_OP,
                                           input=input,
                                           residual=residual,
                                           weight=weight,
                                           epsilon=self.epsilon)
            quantized = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=norm_add[1],
                                            scale=scale)
            # (quantized result, residual)
            return quantized[1], norm_add[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        residual=residual,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon)
            # (quantized result, residual)
            return fused[1], fused[2]

        # Example tensors for tracing; order matches the parameters above.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
210
211
212
213
214
215
216


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses ``rms_norm`` followed by dynamic-scale quant into one op.

    The quant op also outputs ``scale``, which both the pattern and the
    replacement return alongside the quantized result.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        dynamic_scale = ScaleDesc(torch.float32, False, group_shape)
        dynamic_quant = QuantKey(dtype=quant_dtype,
                                 scale=dynamic_scale,
                                 symmetric=symmetric)
        super().__init__(
            epsilon, FusedRMSQuantKey(fused_add=False, quant=dynamic_quant))

    def register(self, pm_pass: PatternMatcherPass):
        """Add the rms_norm -> dynamic quant rewrite to ``pm_pass``."""

        # Plain closures, not methods: a ``self`` parameter would be picked
        # up by the tracer.
        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                    input: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            normed = auto_functionalized(RMS_OP,
                                         result=result_rms,
                                         input=input,
                                         weight=weight,
                                         epsilon=self.epsilon)
            quantized = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=normed[1],
                                            scale=scale,
                                            scale_ub=None)
            # (quantized result, scale)
            return quantized[1], quantized[2]

        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # residual=None: this variant has no residual add
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon,
                                        scale_ub=None,
                                        residual=None)
            # (quantized result, scale)
            return fused[1], fused[2]

        # Example tensors for tracing; order matches the parameters above.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)
275
276
277
278
279
280
281


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuses ``fused_add_rms_norm`` followed by dynamic-scale quant.

    The fused op receives the residual explicitly; the result, residual
    and the ``scale`` output of the quant are all returned.
    """

    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        dynamic_scale = ScaleDesc(torch.float32, False, group_shape)
        dynamic_quant = QuantKey(dtype=quant_dtype,
                                 scale=dynamic_scale,
                                 symmetric=symmetric)
        super().__init__(
            epsilon, FusedRMSQuantKey(fused_add=True, quant=dynamic_quant))

    def register(self, pm_pass: PatternMatcherPass):
        """Add the fused_add_rms_norm -> dynamic quant rewrite."""

        # Plain closures, not methods: a ``self`` parameter would be picked
        # up by the tracer.
        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
                    scale: torch.Tensor):
            norm_add = auto_functionalized(RMS_ADD_OP,
                                           input=input,
                                           residual=residual,
                                           weight=weight,
                                           epsilon=self.epsilon)
            quantized = auto_functionalized(self.QUANT_OP,
                                            result=result,
                                            input=norm_add[1],
                                            scale=scale,
                                            scale_ub=None)
            # (quantized result, residual, scale)
            return quantized[1], norm_add[2], quantized[2]

        def replacement(result: torch.Tensor, input: torch.Tensor,
                        residual: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            fused = auto_functionalized(self.FUSED_OP,
                                        result=result,
                                        input=input,
                                        weight=weight,
                                        scale=scale,
                                        epsilon=self.epsilon,
                                        scale_ub=None,
                                        residual=residual)
            # The fused op's outputs come back as (result, scale, residual);
            # reorder to match the pattern's (result, residual, scale).
            return fused[1], fused[3], fused[2]

        # Example tensors for tracing; order matches the parameters above.
        example_inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, example_inputs,
                                pm.fwd_only, pm_pass)


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.

    A fusion is only registered when both its standalone quant op
    (``QUANT_OPS``) and its fused op (``FUSED_OPS``) are available. Other
    fusions are skipped with a debug log instead of tripping the assertions
    in ``RMSNormQuantPattern`` and aborting construction of the whole pass.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass")

        # Every fusion this pass knows about, in registration order, paired
        # with the key its pattern resolves ops for. Assumes these canonical
        # keys equal the QuantKey each pattern builds from FP8_DTYPE; they are
        # the keys QUANT_OPS / FUSED_OPS are populated with.
        candidates = [
            # rms_norm + static fp8 quant
            (FusedRMSQuantKey(kFp8StaticTensorSym, False),
             RMSNormStaticQuantPattern),
            # fused_add_rms_norm + static fp8 quant
            (FusedRMSQuantKey(kFp8StaticTensorSym, True),
             FusedAddRMSNormStaticQuantPattern),
            # rms_norm + dynamic per-token fp8 quant
            (FusedRMSQuantKey(kFp8DynamicTokenSym, False),
             RMSNormDynamicQuantPattern),
            # fused_add_rms_norm + dynamic per-token fp8 quant
            (FusedRMSQuantKey(kFp8DynamicTokenSym, True),
             FusedAddRMSNormDynamicQuantPattern),
        ]

        supported = []
        for key, pattern_cls in candidates:
            if key.quant in QUANT_OPS and key in FUSED_OPS:
                supported.append(pattern_cls)
            else:
                logger.debug("Skipping %s: ops for %s are not registered",
                             pattern_cls.__name__, key)
        if not supported:
            logger.warning(
                "No rms_norm+quant fusion ops are registered; "
                "%s will not rewrite anything.",
                type(self).__name__)

        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in supported:
                pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered rewrites to ``graph`` in place."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Hash of this pass's source and the pattern classes it uses."""
        return self.hash_source(self, RMSNormQuantPattern,
                                RMSNormStaticQuantPattern,
                                RMSNormDynamicQuantPattern,
                                FusedAddRMSNormStaticQuantPattern,
                                FusedAddRMSNormDynamicQuantPattern)