test_fusion.py 15.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
6
7
import pytest
import torch

8
import vllm.config
9
import vllm.plugins
10
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
11
12
13
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
14
from vllm.compilation.noop_elimination import NoOpEliminationPass
15
from vllm.compilation.post_cleanup import PostCleanupPass
16
17
18
19
20
21
22
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    FP8ScaledMMLinearKernel,
40
)
41
from vllm.model_executor.layers.quantization.utils.quant_utils import (
42
43
44
45
    GroupShape,
    QuantKey,
    ScaleDesc,
)
46
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
47
    cutlass_block_fp8_supported,
48
)
49
from vllm.platforms import current_platform
50
51
52
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
53

54
from ..utils import TestBlockFP8Layer, TestFP8Layer
55
56
from .backend import TestBackend

57
58
# FP8 dtype reported by the active platform; every QuantKey in this file uses it.
FP8_DTYPE = current_platform.fp8_dtype()

# Unfused custom RMSNorm ops (without / with the fused residual add). They are
# the ops expected in the pre-fusion graph when the rms_norm custom op is on.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Kernel and group_shape combinations: (kernel, group_shape)
# A kernel of None marks a blockwise entry, which has no kernel abstraction.
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# Combinations used by the non-Aiter test: the CUDA list when running on CUDA,
# otherwise the ROCm list.
KERNEL_GROUPSHAPE_COMBINATIONS = (
    CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
    if current_platform.is_cuda()
    else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)

# For Aiter tests we toggle use_aiter_quant_op
# Entries are (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise (no kernel abstraction)
    (None, GroupShape(1, 128), True),
]

112
113

class TestModel(torch.nn.Module):
    """RMSNorm / FP8-linear stack used to exercise RMSNorm + quant fusion.

    ``forward`` interleaves four ``RMSNorm`` layers with three FP8 linear
    layers, so the output of each of the first three norms is fed straight
    into a quantized linear layer. The ``ops_in_model_*`` helpers list the ops
    the tests look for in the graph before and after the fusion pass.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: type[FP8ScaledMMLinearKernel] | None,
        group_shape: GroupShape,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the norm layers and the FP8 linear layers.

        Args:
            hidden_size: Size of each ``RMSNorm`` and of the square
                ``(hidden_size, hidden_size)`` weight of every linear layer.
            eps: Epsilon passed to each ``RMSNorm``.
            force_kernel: Scaled-mm kernel class forwarded to
                ``TestFP8Layer``; unused on the blockwise path.
            group_shape: Activation quantization group shape. A per-group
                shape selects the blockwise ``TestBlockFP8Layer`` path.
            use_aiter_fusion: Whether the model is tested against the AITER
                fusion pass. Selects the expected ops and, on the blockwise
                path, is passed as ``transpose_weights``.
            use_aiter_quant: Whether the linear layers should use the AITER
                quantization op.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestBlockFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    group_shape=group_shape,
                    cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                    use_aiter_and_is_supported=use_aiter_quant,
                    transpose_weights=use_aiter_fusion,
                )
                for _ in range(3)
            ]

            # With the AITER quant op the quant_fp8 custom op is treated as
            # disabled; otherwise ask the layer's input quant op.
            self.enable_quant_fp8_custom_op = (
                False
                if use_aiter_quant
                else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
            )

        else:
            # Only per-tensor activation quantization uses a static scale.
            is_static = group_shape == GroupShape.PER_TENSOR
            act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
            w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.weight_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    activation_quant_key=self.activation_quant_key,
                    weight_quant_key=self.weight_quant_key,
                    force_kernel=force_kernel,
                )
                for _ in range(3)
            ]

            # Enable aiter quantization if requested
            for layer in self.fp8_linear_layers:
                layer.kernel.quant_fp8.use_aiter = use_aiter_quant

            self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
                0
            ].is_quant_fp8_enabled()

    def forward(self, x):
        """Run norm -> linear three times, then a final residual norm.

        Returns the normalized output of the last ``RMSNorm``.
        """
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear_layers[1](y2)

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear_layers[2](y3)

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the quant ops expected in the graph before fusion.

        AITER-specific ops are returned when the matching AITER flags are set;
        every other configuration falls through to the common path.
        """
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        """Return the fused ops expected in the graph after fusion."""
        if self.use_aiter_fusion:
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]

        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return the RMSNorm ops expected before fusion.

        The caller checks these with ``fully_replaced=False``, i.e. they are
        not required to disappear completely after the pass.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )

260

261
262
263
264
265
266
267
268
269
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Helper function for common fusion test logic.

    Compiles ``model`` twice, once with ``fusion_pass`` in the pass list and
    once without it, runs both on the same random input and checks that the
    results agree and that the expected ops are present before/after the pass.

    Must be called within vllm_config context.

    Args:
        model: Model under test; must provide ``ops_in_model_before()`` and
            ``ops_in_model_after()``.
        fusion_pass: Fusion pass under test; must expose ``matched_count``.
        vllm_config: Config used to build the no-op elimination and cleanup
            passes.
        dtype: Selects the comparison tolerance (tighter for float16).
        hidden_size: Second dimension of the random input.
        num_tokens: First dimension of the random input (marked dynamic).

    Returns:
        ``(backend, backend2)``: the test backends used for the fused and the
        unfused compilation, in that order.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    backend2 = TestBackend(noop_pass, cleanup_pass)

    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    model_unfused = torch.compile(model, backend=backend2)
    result_unfused = model_unfused(x)

    atol, rtol = (2e-3, 2e-3) if dtype == torch.float16 else (1e-2, 1e-2)

    torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

    # The test model has three norm -> FP8 linear pairs.
    assert fusion_pass.matched_count == 3, (
        f"expected 3 fusion matches, got {fusion_pass.matched_count}"
    )
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())

    return backend, backend2
300
301


302
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check RMSNorm + FP8 quant fusion for each kernel / group-shape pair.

    ``kernel_groupshape`` is a ``(force_kernel, group_shape)`` tuple. The two
    ``enable_*_custom_op`` flags decide which custom ops are switched on in the
    compilation config.
    """
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            use_aiter_fusion=False,
            use_aiter_quant=False,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(g) -> int:
                """Count the ``aten.add`` nodes found in graph ``g``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 7
            assert n_add_nodes(backend.graph_post_pass) == 2
381
382
383
384
385
386
387


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check the ROCm AITER RMSNorm + quant fusion pass.

    ``kernel_groupshape_quant`` is a
    ``(force_kernel, group_shape, use_aiter_quant_op)`` tuple. AITER is
    switched on through ``VLLM_ROCM_USE_AITER`` for the duration of the test.
    """
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    try:
        with (
            vllm.config.set_current_vllm_config(vllm_config),
            monkeypatch.context() as m,
        ):
            from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass

            m.setenv("VLLM_ROCM_USE_AITER", "1")
            # Make rocm_aiter_ops pick up the patched environment variable.
            rocm_aiter_ops.refresh_env_variables()

            torch.set_default_device("cuda")
            torch.set_default_dtype(dtype)
            torch.manual_seed(1)

            fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)

            model = TestModel(
                hidden_size=hidden_size,
                eps=eps,
                force_kernel=force_kernel,
                group_shape=group_shape,
                use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
                use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
            )

            _run_fusion_test(
                model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
            )
    finally:
        # The environment has been restored by now; refresh again so the
        # AITER state enabled above does not leak into later tests.
        rocm_aiter_ops.refresh_env_variables()