"tests/entrypoints/serve/disagg/test_serving_tokens.py" did not exist on "9e9acce577cc8d6daf6db7aacc24a939c08391ca"
fusion.py 12.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, NamedTuple
4
5

import torch
6
7
import torch._inductor.pattern_matcher as pm
from torch import fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
11

12
from vllm.config import VllmConfig, get_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
18
19
20
21
22
23
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
24
from vllm.platforms import current_platform
25

26
from .inductor_pass import enable_fake_mode
27
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
28
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
29

30
logger = init_logger(__name__)

# Quantized element dtypes.
# fp8: whatever fp8 dtype the current platform reports.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): uint8 presumably acts as the storage dtype for packed fp4
# values — confirm against the fp4 kernels.
FP4_DTYPE = torch.uint8


def empty_bf16(*args, **kwargs):
    """Allocate an uninitialized bfloat16 tensor with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype is always ``torch.bfloat16``. The device defaults to
    ``"cuda"`` but can be overridden with an explicit ``device=`` keyword.
    """
    # setdefault avoids a duplicate-keyword TypeError when the caller
    # supplies its own device.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16)


def empty_fp32(*args, **kwargs):
    """Allocate an uninitialized float32 tensor with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype is always ``torch.float32``. The device defaults to
    ``"cuda"`` but can be overridden with an explicit ``device=`` keyword.
    """
    # setdefault avoids a duplicate-keyword TypeError when the caller
    # supplies its own device.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=torch.float32)


def empty_i32(*args, **kwargs):
    """Allocate an uninitialized int32 tensor with ``torch.empty``.

    All positional and keyword arguments are forwarded to ``torch.empty``.
    The dtype is always ``torch.int32``. The device defaults to ``"cuda"``
    but can be overridden with an explicit ``device=`` keyword.
    """
    # setdefault avoids a duplicate-keyword TypeError when the caller
    # supplies its own device.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=torch.int32)


# Handles to vLLM's custom RMSNorm ops: the plain variant and the one that
# also performs the residual add.
RMS_OP, RMS_ADD_OP = (
    torch.ops._C.rms_norm.default,
    torch.ops._C.fused_add_rms_norm.default,
)

# Standalone (unfused) quantization custom ops, keyed by quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,
    kFp8DynamicTokenSym: (
        torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
    ),
}
# The NVFP4 entry is added only on CUDA, and only when the compiled
# extension actually exposes ``scaled_fp4_quant``.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default


class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    # Quantization scheme applied after the RMSNorm.
    quant: QuantKey
    # True when the fused op also performs the residual add.
    fused_add: bool

    def __str__(self):
        # Report the actual class name; the previous text said
        # "FusedQuantKey", which misleads in error messages.
        residual = "with residual" if self.fused_add else "without residual"
        return f"FusedRMSQuantKey({self.quant}, {residual})"


# Fused RMSNorm + quant custom op for each supported
# (quantization scheme, residual-add) combination.
# Dynamic per-token quant uses one op for both variants: the fused-add
# patterns pass ``residual`` to it, and the plain patterns pass ``None``.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(quant=kFp8StaticTensorSym, fused_add=False): (
        torch.ops._C.rms_norm_static_fp8_quant.default
    ),
    FusedRMSQuantKey(quant=kFp8StaticTensorSym, fused_add=True): (
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    ),
    FusedRMSQuantKey(quant=kFp8DynamicTokenSym, fused_add=False): (
        torch.ops._C.rms_norm_dynamic_per_token_quant.default
    ),
    FusedRMSQuantKey(quant=kFp8DynamicTokenSym, fused_add=True): (
        torch.ops._C.rms_norm_dynamic_per_token_quant.default
    ),
}


class RMSNormQuantPattern:
    """Common setup shared by the RMSNorm + quant fusion patterns.

    Stores the epsilon and quant dtype, resolves the fused op for ``key``
    from ``FUSED_OPS``, and builds the matchers used to express the unfused
    pattern: a plain or fused-add RMSNorm matcher, plus an fp8 quant matcher.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # Model dtype from the active vLLM config; None when no model config
        # is present.
        vllm_config = get_current_vllm_config()
        if vllm_config.model_config:
            self.model_dtype = vllm_config.model_config.dtype
        else:
            self.model_dtype = None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # The residual-add variant needs its own matcher.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(key.quant)


class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse ``rms_norm`` followed by static per-tensor quantization.

    The matched subgraph is rewritten into a single call to the fused op
    chosen by the base class for the (static quant, no residual) key.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        static_quant = QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(
            epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=False)
        )

    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern/replacement pair on ``pm_pass``."""
        # Closures instead of bound methods: a ``self`` argument would
        # interfere with tracing.

        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            normed = self.rmsnorm_matcher(input, weight)
            return self.quant_matcher(normed, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # A match against native rms-norm may have had its dtype
            # conversions optimized out, so cast to the model dtype defensively.
            input = input.to(dtype=self.model_dtype)

            out = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )
            # Index 1 holds the updated ``result`` buffer.
            return fused[1]

        # Example inputs: (input, weight) from the rms-norm matcher, then the
        # quant matcher's scale.
        norm_inputs = self.rmsnorm_matcher.inputs()
        quant_scale = self.quant_matcher.inputs()[1]
        example_inputs = [*norm_inputs, quant_scale]

        # NOTE(review): the return value is discarded; this looks like an
        # eager sanity run of the pattern. Confirm whether it is still needed.
        pattern(*example_inputs)

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """Fuse ``fused_add_rms_norm`` followed by static per-tensor quantization.

    The replacement returns both the quantized output and the updated
    residual, mirroring the values produced by the unfused pattern.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        static_quant = QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(
            epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=True)
        )

    def register(self, pm_pass: PatternMatcherPass):
        """Register the fused-add static-quant replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            normed, residual = self.rmsnorm_matcher(input, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # A match against native rms-norm may have had its dtype
            # conversions optimized out, so cast to the model dtype defensively.
            input = input.to(dtype=self.model_dtype)

            out = torch.empty_like(input, dtype=self.quant_dtype)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )
            # Indices 1 and 2 hold the updated ``result`` and ``residual``.
            return fused[1], fused[2]

        # Example inputs: (input, weight, residual) from the rms-norm matcher,
        # followed by the quant matcher's scale.
        example_inputs = [
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],
        ]
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse ``rms_norm`` followed by dynamic quantization.

    The default group shape is per-token. The fused op computes the scale
    itself, so the replacement returns both the quantized output and the
    scale.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        # float32 scales for the requested group shape. NOTE(review): the
        # ``False`` argument presumably marks the scale as non-static.
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        dynamic_quant = QuantKey(
            dtype=quant_dtype, scale=scale_desc, symmetric=symmetric
        )
        super().__init__(
            epsilon, FusedRMSQuantKey(quant=dynamic_quant, fused_add=False)
        )

    def register(self, pm_pass: PatternMatcherPass):
        """Register the dynamic-quant replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            normed = self.rmsnorm_matcher(input, weight)
            # Yields (result, scale).
            return self.quant_matcher(normed)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # A match against native rms-norm may have had its dtype
            # conversions optimized out, so cast to the model dtype defensively.
            input = input.to(dtype=self.model_dtype)

            out = torch.empty_like(input, dtype=self.quant_dtype)
            out_scale = self.quant_matcher.make_scale(input)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                weight=weight,
                scale=out_scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )
            # Indices 1 and 2 hold the updated ``result`` and ``scale``.
            return fused[1], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """Fuse ``fused_add_rms_norm`` followed by dynamic quantization.

    The default group shape is per-token. The replacement returns the
    quantized output, the updated residual and the computed scale, in the
    same order as the unfused pattern.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        # float32 scales for the requested group shape. NOTE(review): the
        # ``False`` argument presumably marks the scale as non-static.
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        dynamic_quant = QuantKey(
            dtype=quant_dtype, scale=scale_desc, symmetric=symmetric
        )
        super().__init__(
            epsilon, FusedRMSQuantKey(quant=dynamic_quant, fused_add=True)
        )

    def register(self, pm_pass: PatternMatcherPass):
        """Register the fused-add dynamic-quant replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            normed, residual = self.rmsnorm_matcher(input, weight, residual)
            quantized, quant_scale = self.quant_matcher(normed)
            return quantized, residual, quant_scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # A match against native rms-norm may have had its dtype
            # conversions optimized out, so cast to the model dtype defensively.
            input = input.to(dtype=self.model_dtype)

            out = torch.empty_like(input, dtype=self.quant_dtype)
            out_scale = self.quant_matcher.make_scale(input)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=out,
                input=input,
                weight=weight,
                scale=out_scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )
            # Index 1 is ``result``, 2 is ``scale`` and 3 is ``residual``.
            # Reorder them to match the pattern's (result, residual, scale).
            return fused[1], fused[3], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """Fuse rms_norm / fused_add_rms_norm with a following quant custom op.

    Matching pairs are replaced with a single fused rms_norm_quant op.
    Patterns are registered for static per-tensor and dynamic per-token
    fp8 quantization, with and without the residual add.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Fused-add patterns must be registered before the plain rms_norm
        # ones, since the latter is a subset of the former in torch ops.
        ordered_pattern_classes = (
            FusedAddRMSNormStaticQuantPattern,
            RMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            RMSNormDynamicQuantPattern,
        )
        for epsilon in (1e-5, 1e-6):
            for pattern_cls in ordered_pattern_classes:
                pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply every registered fusion pattern to ``graph`` in place."""
        replaced = self.patterns.apply(graph)
        self.matched_count = replaced
        logger.debug("Replaced %s patterns", replaced)

    def uuid(self) -> Any:
        """Hash the source of this pass and all of its pattern classes."""
        pattern_classes = (
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
        return self.hash_source(self, *pattern_classes)