# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4

5
6
7
import pytest
import torch

8
import vllm.config
9
import vllm.ir.ops
10
import vllm.plugins
11
from tests.compile.backend import TestBackend
12
from tests.utils import TestFP8Layer
13
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
14
15
16
17
18
19
20
21
22
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
    FUSED_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
23
24
25
26
27
28
29
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
30
from vllm.model_executor.kernels.linear import (
31
    AiterFp8BlockScaledMMKernel,
32
    ChannelWiseTorchFP8ScaledMMLinearKernel,
33
    CutlassFp8BlockScaledMMKernel,
34
    CutlassFP8ScaledMMLinearKernel,
35
36
    DeepGemmFp8BlockScaledMMKernel,
    FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
37
38
39
    FlashInferFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    ROCmFP8ScaledMMLinearKernel,
40
    RowWiseTorchFP8ScaledMMLinearKernel,
41
42
    TritonFp8BlockScaledMMKernel,
    _KernelT,
43
)
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
from vllm.model_executor.layers.quantization.utils.quant_utils import (
46
    GroupShape,
47
    create_fp8_quant_key,
48
)
49
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
50
    cutlass_block_fp8_supported,
51
)
52
from vllm.platforms import current_platform
53
54
55
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
56

57
58
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Fused add + RMSNorm custom op. It is the op expected in the pre-fusion graph
# when the RMSNorm custom op is enabled (see ops_in_model_before_partial).
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

61
62
63
64
65
66
67
68
69
70
71
72
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes
    (FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
    (CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

# ROCm kernels: (kernel, group_shape)
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

# Select the parametrization list for the platform the tests run on:
# the CUDA list on CUDA, the ROCm list on everything else.
if current_platform.is_cuda():
    KERNEL_GROUPSHAPE_COMBINATIONS = CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
else:
    KERNEL_GROUPSHAPE_COMBINATIONS = ROCM_KERNEL_GROUPSHAPE_COMBINATIONS

# For Aiter tests we toggle use_aiter_quant_op.
# Each entry is (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise
    (AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]

114
115

class TestModel(torch.nn.Module):
    """Small model used to exercise RMSNorm + FP8 quant fusion.

    The forward pass is a relu followed by four RMSNorm layers interleaved
    with three FP8 linear layers. The first norm is called without a
    residual; the remaining three are called with one.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: type[_KernelT] | None,
        group_shape: GroupShape,
        dtype: torch.dtype,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the norm and FP8 linear layers.

        Args:
            hidden_size: Size passed to RMSNorm and used for the square
                weight shape of every FP8 linear layer.
            eps: Epsilon passed to RMSNorm.
            force_kernel: Kernel class forwarded to TestFP8Layer, or None.
            group_shape: Activation quantization group shape; also decides
                whether the blockwise quant keys are used.
            dtype: Forwarded to TestFP8Layer as ``input_dtype``.
            use_aiter_fusion: Selects the aiter fused ops as the expected
                post-fusion ops and transposes the linear weights.
            use_aiter_quant: Value assigned to ``quant_fp8.use_aiter`` on
                every linear layer's kernel.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            # Activations are quantized dynamically with the given group
            # shape; weights use a static square block of the same width.
            block_size = group_shape.col
            self.activation_quant_key = create_fp8_quant_key(
                static=False, group_shape=group_shape
            )
            self.weight_quant_key = create_fp8_quant_key(
                static=True, group_shape=GroupShape(block_size, block_size)
            )
        else:
            # Only the per-tensor case uses static activation quantization.
            is_static = group_shape == GroupShape.PER_TENSOR
            self.activation_quant_key = create_fp8_quant_key(
                is_static, group_shape=group_shape
            )
            self.weight_quant_key = create_fp8_quant_key(
                static=True, group_shape=group_shape
            )

        self.fp8_linear_layers = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.activation_quant_key,
                weight_quant_key=self.weight_quant_key,
                force_kernel=force_kernel,
                transpose_weights=use_aiter_fusion,
                input_dtype=dtype,
            )
            for _ in range(3)
        ]

        # Enable aiter quantization if requested
        for layer in self.fp8_linear_layers:
            layer.kernel.quant_fp8.use_aiter = use_aiter_quant

        self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
            0
        ].is_quant_fp8_enabled()

    def forward(self, x):
        """Run relu, then alternate RMSNorm and FP8 linear layers."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear_layers[1](y2)

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear_layers[2](y3)

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the quant ops expected in the graph before fusion."""
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        """Return the fused ops expected in the graph after fusion."""
        if self.use_aiter_fusion:
            # The aiter patterns are imported lazily, only on the aiter path.
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]

        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return the norm ops checked with ``fully_replaced=False``.

        The final RMSNorm is not followed by a linear layer, so callers
        do not expect these ops to disappear entirely after fusion.
        """
        return [torch.ops.vllm_ir.rms_norm] + (
            [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
        )

249

250
251
252
253
254
255
256
257
258
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Helper function for common fusion test logic.

    Must be called within vllm_config context.

    Compiles ``model`` twice, once with ``fusion_pass`` and once without,
    checks that both produce close outputs, and verifies the ops present
    before and after the fusion pass.

    Args:
        model: Module exposing ``ops_in_model_before`` and
            ``ops_in_model_after``.
        fusion_pass: Fusion pass under test; must expose ``matched_count``.
        vllm_config: Config used to build the noop-elimination and
            cleanup passes.
        dtype: Used only to pick the comparison tolerances.
        hidden_size: Second dimension of the random input.
        num_tokens: First (dynamic) dimension of the random input.

    Returns:
        Tuple ``(backend, backend2)`` of the fused and unfused TestBackend.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    backend2 = TestBackend(noop_pass, cleanup_pass)

    # No device/dtype is passed here: the input uses whatever torch defaults
    # are active (both tests in this file set them before calling).
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    model_unfused = torch.compile(model, backend=backend2)
    result_unfused = model_unfused(x)

    if dtype == torch.float16:
        atol, rtol = (2e-3, 2e-3)
    else:
        atol, rtol = (1e-2, 1e-2)

    torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

    # TestModel has three RMSNorm -> FP8 linear pairs.
    assert fusion_pass.matched_count == 3
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())

    return backend, backend2
289
290


291
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check RMSNormQuantFusionPass across kernels and group shapes.

    ``kernel_groupshape`` is a ``(kernel, group_shape)`` tuple. The two
    ``enable_*`` flags decide which custom ops are added to the
    compilation config.
    """
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with (
        vllm.config.set_current_vllm_config(vllm_config),
        vllm_config.kernel_config.ir_op_priority.set_priority(),
    ):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            dtype=dtype,
            use_aiter_fusion=False,
            use_aiter_quant=False,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(g):
                """Count the aten.add nodes in graph ``g``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))

            # rms_norm is IR, not included
            # 6 = 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 6
            assert n_add_nodes(backend.graph_post_pass) == 2
375
376
377
378
379
380
381


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check RocmAiterRMSNormQuantFusionPass with aiter enabled.

    ``kernel_groupshape_quant`` is a
    ``(kernel, group_shape, use_aiter_quant_op)`` tuple. The test sets
    ``VLLM_ROCM_USE_AITER=1`` for its duration via ``monkeypatch``.
    """
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        # Imported lazily so the module is only loaded when the test runs.
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
            RocmAiterRMSNormQuantFusionPass,
        )

        m.setenv("VLLM_ROCM_USE_AITER", "1")
        # NOTE(review): the env var is restored when the monkeypatch context
        # exits, but refresh_env_variables() is not called again afterwards.
        # Confirm that no cached aiter state leaks into later tests.
        rocm_aiter_ops.refresh_env_variables()

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            dtype=dtype,
            use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
            use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
        )

        _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )