# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.config
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
    FUSED_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)

# FP8 dtype reported by the current platform; used for every quant key below.
FP8_DTYPE = current_platform.fp8_dtype()

# Unfused RMSNorm custom ops (without / with the fused residual add).
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# Any non-CUDA platform falls back to the ROCm list.
KERNEL_GROUPSHAPE_COMBINATIONS = (
    CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
    if current_platform.is_cuda()
    else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)

# For Aiter tests we toggle use_aiter_quant_op.
# Entries are (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise (no kernel abstraction)
    (None, GroupShape(1, 128), True),
]


class TestModel(torch.nn.Module):
    """Toy model used to exercise the RMSNorm + FP8-quant fusion passes.

    The model holds four ``RMSNorm`` layers and three FP8 linear layers.
    ``forward`` alternates them (norm -> linear -> norm -> ...), so each of
    the first three norms feeds an FP8 linear layer.
    """

    # This is a helper model, not a test case: stop pytest from trying to
    # collect a ``Test*``-named class that defines ``__init__``.
    __test__ = False

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: type[FP8ScaledMMLinearKernel] | None,
        group_shape: GroupShape,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the norm and FP8 linear layers.

        Args:
            hidden_size: Feature size; the linear weights are
                ``(hidden_size, hidden_size)``.
            eps: Epsilon passed to every ``RMSNorm``.
            force_kernel: Scaled-MM kernel class forwarded to ``TestFP8Layer``.
                Only used on the non-blockwise path.
            group_shape: Quantization group shape. A per-group shape selects
                the blockwise path (``TestBlockFP8Layer``).
            use_aiter_fusion: Whether the Aiter fusion pass is under test;
                selects which ops ``ops_in_model_*`` report.
            use_aiter_quant: Whether the Aiter quantization op is used by the
                linear layers.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestBlockFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    group_shape=group_shape,
                    cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                    use_aiter_and_is_supported=use_aiter_quant,
                    transpose_weights=use_aiter_fusion,
                )
                for _ in range(3)
            ]

            # With Aiter quant the quant_fp8 custom op is reported as disabled;
            # otherwise ask the first layer's input quant op.
            self.enable_quant_fp8_custom_op = (
                False
                if use_aiter_quant
                else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
            )

        else:
            # Per-tensor activation scales are static; others are dynamic.
            is_static = group_shape == GroupShape.PER_TENSOR
            act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
            w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.weight_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    activation_quant_key=self.activation_quant_key,
                    weight_quant_key=self.weight_quant_key,
                    force_kernel=force_kernel,
                )
                for _ in range(3)
            ]

            # Enable aiter quantization if requested
            for layer in self.fp8_linear_layers:
                layer.kernel.quant_fp8.use_aiter = use_aiter_quant

            self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
                0
            ].is_quant_fp8_enabled()

    def forward(self, x):
        """Run norm -> linear three times, then a final residual norm."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear_layers[1](y2)

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear_layers[2](y3)

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the quant op(s) expected in the graph before fusion."""
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        """Return the fused op(s) expected in the graph after fusion."""
        if self.use_aiter_fusion:
            # Imported lazily, only when the Aiter path is requested.
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]

        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return the RMSNorm op(s) expected in the graph before fusion.

        Callers check these with ``fully_replaced=False``.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )


def _run_fusion_test(
    model: torch.nn.Module,
    fusion_pass,
    vllm_config: VllmConfig,
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
) -> tuple[TestBackend, TestBackend]:
    """Helper function for common fusion test logic.

    Compiles ``model`` twice, once with ``fusion_pass`` in the pass list and
    once without, and asserts that both produce close outputs, that the pass
    matched exactly three times, and that the expected ops appear before and
    after the pass.

    Must be called within vllm_config context.

    Args:
        model: Model exposing ``ops_in_model_before`` / ``ops_in_model_after``.
        fusion_pass: Fusion pass under test; must expose ``matched_count``.
        vllm_config: Config used to build the no-op and cleanup passes.
        dtype: Selects the comparison tolerances.
        hidden_size: Second dimension of the random input.
        num_tokens: First dimension of the random input (marked dynamic).

    Returns:
        ``(backend, backend2)``: the backend with the fusion pass and the
        backend without it.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    backend2 = TestBackend(noop_pass, cleanup_pass)

    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    model_unfused = torch.compile(model, backend=backend2)
    result_unfused = model_unfused(x)

    # float16 gets the tighter tolerance; every other dtype the looser one.
    if dtype == torch.float16:
        atol, rtol = (2e-3, 2e-3)
    else:
        atol, rtol = (1e-2, 1e-2)

    torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

    # One match per FP8 linear layer fed by an RMSNorm.
    assert fusion_pass.matched_count == 3
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())

    return backend, backend2


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check RMSNormQuantFusionPass across kernel / group-shape combinations.

    ``kernel_groupshape`` is a ``(force_kernel, group_shape)`` pair. The
    ``enable_*_custom_op`` flags control which custom ops are enabled in the
    compilation config.
    """
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            use_aiter_fusion=False,
            use_aiter_quant=False,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(g) -> int:
                """Count the ``aten.add`` nodes in graph ``g``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 7
            assert n_add_nodes(backend.graph_post_pass) == 2


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check RocmAiterRMSNormQuantFusionPass with Aiter enabled.

    ``kernel_groupshape_quant`` is a
    ``(force_kernel, group_shape, use_aiter_quant_op)`` triple.
    """
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        # Imported lazily, inside the ROCm/aiter-gated test body.
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
            RocmAiterRMSNormQuantFusionPass,
        )

        # Set the env var first, then refresh so rocm_aiter_ops re-reads it.
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
            use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
        )

        _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )