test_fusion.py 15.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
6
7
import pytest
import torch

8
import vllm.config
9
import vllm.ir.ops
10
import vllm.plugins
11
12
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
13
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
14
15
16
17
18
19
20
21
22
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
    FUSED_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
23
24
25
26
27
28
29
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
30
31
from vllm.model_executor.kernels.linear import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
32
33
    CutlassFP8ScaledMMLinearKernel,
    FlashInferFP8ScaledMMLinearKernel,
34
    FP8ScaledMMLinearKernel,
35
36
    PerTensorTorchFP8ScaledMMLinearKernel,
    ROCmFP8ScaledMMLinearKernel,
37
    RowWiseTorchFP8ScaledMMLinearKernel,
38
)
39
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.quantization.utils.quant_utils import (
41
42
43
44
    GroupShape,
    QuantKey,
    ScaleDesc,
)
45
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
46
    cutlass_block_fp8_supported,
47
)
48
from vllm.platforms import current_platform
49
50
51
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
52

53
54
# FP8 dtype reported by the current platform; used as the dtype of every
# QuantKey built in this file.
FP8_DTYPE = current_platform.fp8_dtype()

55
56
# Fused add + RMSNorm custom op; the tests look for it in the pre-fusion graph
# when the rms_norm custom op is enabled.
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Kernel and group_shape combinations: (kernel, group_shape)
# A kernel of None is paired with a blockwise group shape; TestModel does not
# use the kernel argument for those entries.
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# ROCm kernels
# Each entry is a (kernel, group_shape) pair; kernel is None for the blockwise
# group shapes.
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# Parametrization list for test_fusion_rmsnorm_quant: the CUDA list when
# running on CUDA, the ROCm list on any other platform.
KERNEL_GROUPSHAPE_COMBINATIONS = (
    CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
    if current_platform.is_cuda()
    else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)

# For Aiter tests we toggle use_aiter_quant_op
# Each entry is (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise (no kernel abstraction)
    (None, GroupShape(1, 128), True),
]

107
108

class TestModel(torch.nn.Module):
    """Toy network alternating RMSNorm and FP8 linear layers.

    ``forward`` chains four RMSNorm layers with three FP8 linear layers in
    between. The ``ops_in_model_*`` helpers report which ops the fusion tests
    in this file look for in the graph before and after the fusion pass.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: FP8ScaledMMLinearKernel | None,
        group_shape: GroupShape,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the norm layers and the FP8 linear layers.

        Args:
            hidden_size: Feature size; every linear weight is
                ``(hidden_size, hidden_size)``.
            eps: Epsilon handed to each ``RMSNorm``.
            force_kernel: Forwarded to ``TestFP8Layer`` as ``force_kernel``;
                not used when ``group_shape`` is blockwise.
            group_shape: Activation quantization group shape. A per-group
                shape selects the blockwise layers.
            use_aiter_fusion: Selects which ops ``ops_in_model_before`` /
                ``ops_in_model_after`` report; also passed to the blockwise
                layers as ``transpose_weights``.
            use_aiter_quant: Enables aiter quantization on the linear layers.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        weight_shape = (hidden_size, hidden_size)

        if not group_shape.is_per_group():
            # Per-tensor / per-token path: activation scales are static only
            # for per-tensor quantization; weight scales are always static.
            is_static = group_shape == GroupShape.PER_TENSOR
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE,
                scale=ScaleDesc(torch.float32, is_static, group_shape),
                symmetric=True,
            )
            self.weight_quant_key = QuantKey(
                dtype=FP8_DTYPE,
                scale=ScaleDesc(torch.float32, True, group_shape),
                symmetric=True,
            )
            layers = [
                TestFP8Layer(
                    weight_shape=weight_shape,
                    activation_quant_key=self.activation_quant_key,
                    weight_quant_key=self.weight_quant_key,
                    force_kernel=force_kernel,
                )
                for _ in range(3)
            ]
            self.fp8_linear_layers = layers

            # Enable aiter quantization if requested
            for layer in layers:
                layer.kernel.quant_fp8.use_aiter = use_aiter_quant

            self.enable_quant_fp8_custom_op = layers[0].is_quant_fp8_enabled()
            return

        # Blockwise path: dynamic activation scales, no kernel abstraction.
        self.activation_quant_key = QuantKey(
            dtype=FP8_DTYPE,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=True,
        )
        layers = [
            TestBlockFP8Layer(
                weight_shape=weight_shape,
                group_shape=group_shape,
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=use_aiter_quant,
                transpose_weights=use_aiter_fusion,
            )
            for _ in range(3)
        ]
        self.fp8_linear_layers = layers

        if use_aiter_quant:
            self.enable_quant_fp8_custom_op = False
        else:
            self.enable_quant_fp8_custom_op = layers[
                0
            ].linear_op.input_quant_op.enabled()

    def forward(self, x):
        """Run norm -> (linear -> residual norm) x 3 and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        hidden = torch.relu(x)
        residual = hidden

        normed = self.norm[0](hidden)

        # Each residual norm both consumes and re-emits the residual; it must
        # be threaded through so the fused-add replacement can match.
        for linear, norm in zip(self.fp8_linear_layers, self.norm[1:]):
            projected = linear(normed)
            normed, residual = norm(projected, residual)

        return normed

    def ops_in_model_before(self):
        """Return the quant op(s) expected in the graph before fusion."""
        blockwise = self.group_shape.is_per_group()

        if blockwise and self.use_aiter_fusion:
            # Blockwise aiter path
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]

        if not blockwise and self.use_aiter_quant_op:
            return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        if self.enable_quant_fp8_custom_op:
            return [QUANT_OPS[self.activation_quant_key]]
        return [torch.ops.aten.reciprocal]

    def ops_in_model_after(self):
        """Return the fused op(s) expected in the graph after fusion."""
        if not self.use_aiter_fusion:
            # Regular fusion: one fused op with residual add, one without.
            return [
                FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, fused_add)]
                for fused_add in (True, False)
            ]

        if self.group_shape.is_per_group():
            # Blockwise aiter fusion
            from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                AiterFusedAddRMSFp8GroupQuantPattern,
                AiterRMSFp8GroupQuantPattern,
            )

            patterns = (
                AiterFusedAddRMSFp8GroupQuantPattern,
                AiterRMSFp8GroupQuantPattern,
            )
        else:
            # Per-token aiter fusion
            from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                AiterFusedAddRMSNormDynamicQuantPattern,
                AiterRMSNormDynamicQuantPattern,
            )

            patterns = (
                AiterFusedAddRMSNormDynamicQuantPattern,
                AiterRMSNormDynamicQuantPattern,
            )

        return [pattern.FUSED_OP for pattern in patterns]

    def ops_in_model_before_partial(self):
        """Return pre-fusion ops that the pass is not expected to fully replace."""
        residual_norm_op = (
            RMS_ADD_OP if self.enable_rms_norm_custom_op else torch.ops.aten.rsqrt
        )
        return [torch.ops.vllm_ir.rms_norm, residual_norm_op]

253

254
255
256
257
258
259
260
261
262
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Compile ``model`` with and without ``fusion_pass`` and compare the two.

    Checks that both compiled variants agree numerically, that the fusion
    pass matched three times, and that the graph holds the model's expected
    ops before and after the pass.

    Must be called within vllm_config context.

    Returns:
        ``(backend, backend2)``: the test backend that ran the fusion pass
        and the baseline backend that did not.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    fused_backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    baseline_backend = TestBackend(noop_pass, cleanup_pass)

    inputs = torch.rand(num_tokens, hidden_size)
    # Mark the token dimension as dynamic.
    torch._dynamo.mark_dynamic(inputs, 0)

    fused_out = torch.compile(model, backend=fused_backend)(inputs)
    baseline_out = torch.compile(model, backend=baseline_backend)(inputs)

    # float16 gets a tighter tolerance than the other tested dtype.
    tolerance = 2e-3 if dtype == torch.float16 else 1e-2
    torch.testing.assert_close(
        fused_out, baseline_out, atol=tolerance, rtol=tolerance
    )

    assert fusion_pass.matched_count == 3
    fused_backend.check_before_ops(model.ops_in_model_before())
    fused_backend.check_after_ops(model.ops_in_model_after())

    return fused_backend, baseline_backend
293
294


295
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check RMSNorm + FP8 quant fusion for each kernel / group-shape pair."""
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op:
        if group_shape.is_per_group():
            pytest.skip(
                "Unsupported unwrapped quant fp8 op for blockwise quantization"
            )

    if group_shape == GroupShape(1, 64):
        if cutlass_block_fp8_supported() or is_deep_gemm_supported():
            pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    # Only the custom ops requested by the parametrization are switched on.
    requested_ops = (
        ("+rms_norm", enable_rms_norm_custom_op),
        ("+quant_fp8", enable_quant_fp8_custom_op),
    )
    custom_ops = [op_name for op_name, wanted in requested_ops if wanted]

    model_config = ModelConfig(dtype=dtype)
    pass_config = PassConfig(
        fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
    )
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=custom_ops,
        pass_config=pass_config,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with (
        vllm.config.set_current_vllm_config(vllm_config),
        vllm_config.kernel_config.ir_op_priority.set_priority(),
    ):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            use_aiter_fusion=False,
            use_aiter_quant=False,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        if enable_rms_norm_custom_op:
            return

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        def count_add_nodes(graph):
            return len(list(find_op_nodes(torch.ops.aten.add, graph)))

        # rms_norm is IR, not included
        # 6 = 3x2 (3xRMS_ADD, 2 each)
        assert count_add_nodes(backend.graph_pre_pass) == 6
        assert count_add_nodes(backend.graph_post_pass) == 2
378
379
380
381
382
383
384


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check the ROCm aiter RMSNorm + FP8 quant fusion pass."""
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant

    model_config = ModelConfig(dtype=dtype)
    pass_config = PassConfig(fuse_norm_quant=True, eliminate_noops=True)
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=pass_config,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with (
        vllm.config.set_current_vllm_config(vllm_config),
        monkeypatch.context() as patched_env,
    ):
        # Imported here rather than at module level; this test only runs on
        # ROCm with aiter installed (see the skipif above).
        from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
            RocmAiterRMSNormQuantFusionPass,
        )

        # Set the aiter env var, then refresh rocm_aiter_ops afterwards.
        patched_env.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            force_kernel=force_kernel,
            group_shape=group_shape,
            # Always use aiter fusion ops in aiter test
            use_aiter_fusion=True,
            # Toggle aiter quantization
            use_aiter_quant=use_aiter_quant_op,
        )

        _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )