test_fusion.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
6
7
import pytest
import torch

8
import vllm.config
9
import vllm.ir.ops
10
import vllm.plugins
11
from tests.compile.backend import TestBackend
12
from tests.utils import TestFP8Layer
13
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
14
15
16
17
18
19
20
21
22
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
    FUSED_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
23
24
25
26
27
28
29
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
30
from vllm.model_executor.kernels.linear import (
31
    AiterFp8BlockScaledMMKernel,
32
    ChannelWiseTorchFP8ScaledMMLinearKernel,
33
    CutlassFp8BlockScaledMMKernel,
34
    CutlassFP8ScaledMMLinearKernel,
35
36
    DeepGemmFp8BlockScaledMMKernel,
    FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
37
38
39
    FlashInferFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    ROCmFP8ScaledMMLinearKernel,
40
    RowWiseTorchFP8ScaledMMLinearKernel,
41
42
    TritonFp8BlockScaledMMKernel,
    _KernelT,
43
)
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
from vllm.model_executor.layers.quantization.utils.quant_utils import (
46
    GroupShape,
47
    create_fp8_quant_key,
48
)
49
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
50
    cutlass_block_fp8_supported,
51
)
52
from vllm.platforms import current_platform
53
from vllm.utils.deep_gemm import (
54
    is_deep_gemm_e8m0_used,
55
56
    is_deep_gemm_supported,
)
57

58
59
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Handle to the `fused_add_rms_norm` custom op registered under torch.ops._C.
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

62
63
64
65
66
67
68
69
70
71
72
73
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes
    (FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
    (CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

# ROCm kernels: (kernel, group_shape)
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

# Pick the parametrization list for the platform the tests run on: the CUDA
# list on CUDA, the ROCm list on anything else.
if current_platform.is_cuda():
    KERNEL_GROUPSHAPE_COMBINATIONS = CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
else:
    KERNEL_GROUPSHAPE_COMBINATIONS = ROCM_KERNEL_GROUPSHAPE_COMBINATIONS

# For Aiter tests we toggle use_aiter_quant_op.
# Each entry is (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise
    (AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]

115
116

class TestModel(torch.nn.Module):
    """Small model interleaving RMSNorm and FP8 linear layers.

    The forward pass runs four RMSNorm layers with three FP8 linear layers
    between them. The `ops_in_model_*` helpers return the ops the fusion
    tests expect to find in the graph before and after the fusion pass.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: type[_KernelT] | None,
        group_shape: GroupShape,
        dtype: torch.dtype,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the norm and FP8 linear layers.

        Args:
            hidden_size: Size passed to every RMSNorm; the linear weights are
                (hidden_size, hidden_size).
            eps: Epsilon passed to every RMSNorm.
            force_kernel: Kernel class forwarded to each TestFP8Layer.
            group_shape: Activation quantization group shape; also decides
                whether the blockwise or the per-tensor/per-token quant keys
                are built.
            dtype: Forwarded to each TestFP8Layer as `input_dtype`.
            use_aiter_fusion: Selects the aiter fused ops in
                `ops_in_model_after` and is forwarded to TestFP8Layer as
                `transpose_weights`.
            use_aiter_quant: Written to `kernel.quant_fp8.use_aiter` on every
                FP8 layer.
            *args: Forwarded to `torch.nn.Module.__init__`.
            **kwargs: Forwarded to `torch.nn.Module.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        # The norm layers are kept in a plain Python list (not a ModuleList).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            # Blockwise: dynamic activation quant with the given group shape,
            # static weight quant over square blocks of side `group_shape.col`.
            block_size = group_shape.col
            self.activation_quant_key = create_fp8_quant_key(
                static=False, group_shape=group_shape
            )
            self.weight_quant_key = create_fp8_quant_key(
                static=True, group_shape=GroupShape(block_size, block_size)
            )
        else:
            # Activation quant is static only for the per-tensor group shape.
            is_static = group_shape == GroupShape.PER_TENSOR
            self.activation_quant_key = create_fp8_quant_key(
                is_static, group_shape=group_shape
            )
            self.weight_quant_key = create_fp8_quant_key(
                static=True, group_shape=group_shape
            )

        self.fp8_linear_layers: list[torch.nn.Module] = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.activation_quant_key,
                weight_quant_key=self.weight_quant_key,
                force_kernel=force_kernel,
                transpose_weights=use_aiter_fusion,
                input_dtype=dtype,
            )
            for _ in range(3)
        ]

        # Enable aiter quantization if requested
        for layer in self.fp8_linear_layers:
            layer.kernel.quant_fp8.use_aiter = use_aiter_quant

        self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
            0
        ].is_quant_fp8_enabled()

    def forward(self, x):
        """Run norm -> linear three times, then a final norm; return its output."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear_layers[1](y2)

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear_layers[2](y3)

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the quant op(s) expected in the graph before fusion."""
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        """Return the fused op(s) expected in the graph after fusion."""
        if self.use_aiter_fusion:
            # The aiter pattern classes are imported locally, only on this path.
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]

        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return norm-related ops checked with `fully_replaced=False`."""
        return [torch.ops.vllm_ir.rms_norm] + (
            [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
        )

250

251
252
253
254
255
256
257
258
259
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Helper function for common fusion test logic.

    Compiles `model` twice, once with and once without `fusion_pass`, runs
    both on the same random input, and compares the outputs. It then asserts
    that the pass matched three times and that the model's expected
    before/after ops are present in the fused backend.

    Must be called within vllm_config context.

    Args:
        model: Module exposing `ops_in_model_before()` and
            `ops_in_model_after()`.
        fusion_pass: Pass under test; must expose `matched_count`.
        vllm_config: Config used to build the no-op and cleanup passes.
        dtype: Selects the comparison tolerance (tighter for float16).
        hidden_size: Second dimension of the random input.
        num_tokens: First dimension of the random input (marked dynamic).

    Returns:
        Tuple `(backend, backend2)`: the fused and unfused test backends.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    backend2 = TestBackend(noop_pass, cleanup_pass)

    # No explicit device/dtype is passed here, so the input uses whatever
    # torch defaults are active when this helper is called.
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    model_unfused = torch.compile(model, backend=backend2)
    result_unfused = model_unfused(x)

    if dtype == torch.float16:
        atol, rtol = (2e-3, 2e-3)
    else:
        atol, rtol = (1e-2, 1e-2)

    torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

    assert fusion_pass.matched_count == 3
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())

    return backend, backend2
290
291


292
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check RMSNormQuantFusionPass across kernels, group shapes and op modes.

    `kernel_groupshape` is a `(kernel, group_shape)` pair. The two `enable_*`
    flags decide whether `+rms_norm` / `+quant_fp8` are added to `custom_ops`.
    """
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    # TODO(quant-rms-fusion): DeepGEMM UE8M0 activation quant on B200 lowers
    # to a packed int32-scale op (per_token_group_quant_fp8_packed_for_deepgemm),
    # but the rms+quant fusion pattern only matches the fp32-scale variant, so
    # the fused output gets a mismatched scale layout and produces NaN. Only
    # reproduces on bf16 (DeepGEMM UE8M0 on B200 is bf16-only).
    # To re-enable: make rms_norm_per_block_quant emit packed UE8M0 scales
    # and extend the fusion pattern to rewrite the packed activation quant.
    deepgemm_kernels = (
        DeepGemmFp8BlockScaledMMKernel,
        FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
    )
    if (
        dtype == torch.bfloat16
        and force_kernel in deepgemm_kernels
        and is_deep_gemm_e8m0_used()
    ):
        pytest.skip(
            "rms+quant fusion does not yet match the packed UE8M0 DeepGEMM path"
        )

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with (
        vllm.config.set_current_vllm_config(vllm_config),
        vllm_config.kernel_config.ir_op_priority.set_priority(),
    ):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            dtype=dtype,
            use_aiter_fusion=False,
            use_aiter_quant=False,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(graph):
                """Count the aten.add nodes in `graph`."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

            # rms_norm is IR, not included
            # 6 = 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 6
            assert n_add_nodes(backend.graph_post_pass) == 2
396
397
398
399
400
401
402


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check RocmAiterRMSNormQuantFusionPass on ROCm with aiter enabled.

    `kernel_groupshape_quant` is a `(kernel, group_shape, use_aiter_quant_op)`
    triple. Both `+rms_norm` and `+quant_fp8` custom ops are always enabled.
    """
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        # Imported inside the test body rather than at module level.
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
            RocmAiterRMSNormQuantFusionPass,
        )

        # Set the env var, then ask rocm_aiter_ops to refresh its env state.
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            dtype=dtype,
            use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
            use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
        )

        _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )