test_fusion.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
6
7
import pytest
import torch

8
import vllm.config
9
import vllm.plugins
10
11
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
12
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
13
14
15
16
17
18
19
20
21
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
    FUSED_OPS,
    FusedRMSQuantKey,
    RMSNormQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
22
23
24
25
26
27
28
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
29
30
from vllm.model_executor.kernels.linear import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
31
32
    CutlassFP8ScaledMMLinearKernel,
    FlashInferFP8ScaledMMLinearKernel,
33
    FP8ScaledMMLinearKernel,
34
35
    PerTensorTorchFP8ScaledMMLinearKernel,
    ROCmFP8ScaledMMLinearKernel,
36
    RowWiseTorchFP8ScaledMMLinearKernel,
37
)
38
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.quantization.utils.quant_utils import (
40
41
42
43
    GroupShape,
    QuantKey,
    ScaleDesc,
)
44
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
45
    cutlass_block_fp8_supported,
46
)
47
from vllm.platforms import current_platform
48
49
50
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
51

52
53
FP8_DTYPE = current_platform.fp8_dtype()

# Unfused RMSNorm custom ops (plain and fused-residual-add variants).
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Kernel and group_shape combinations: (kernel, group_shape).
# A kernel of None means a blockwise group shape with no kernel abstraction.

# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # FlashInferFP8ScaledMMLinearKernel supports per-tensor only
    (FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
    (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # ROCmFP8ScaledMMLinearKernel supports per-tensor only
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
    # RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
    # Blockwise group shapes (no kernel abstraction)
    (None, GroupShape(1, 128)),
    (None, GroupShape(1, 64)),
]

# Pick the table for the platform the tests run on (anything that is not
# CUDA uses the ROCm table).
KERNEL_GROUPSHAPE_COMBINATIONS = (
    CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
    if current_platform.is_cuda()
    else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)

# For Aiter tests we toggle use_aiter_quant_op:
# entries are (kernel, group_shape, use_aiter_quant_op).
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
    # Per-tensor with ROCmFP8ScaledMMLinearKernel (aiter quant op off)
    (ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
    # Per-token with RowWiseTorchFP8ScaledMMLinearKernel
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
    # Blockwise (no kernel abstraction)
    (None, GroupShape(1, 128), True),
]

107
108

class TestModel(torch.nn.Module):
    """Small RMSNorm / FP8-linear stack used to exercise RMSNorm+quant fusion.

    ``forward`` runs ``norm[0] -> fp8_linear[0] -> norm[1] -> fp8_linear[1]
    -> norm[2] -> fp8_linear[2] -> norm[3]``, threading a residual through
    ``norm[1:]``. That yields three norm -> quantized-linear pairs (one plain
    RMSNorm and two fused-add RMSNorms) for a fusion pass to match. The final
    norm has no quantized linear after it.

    The ``ops_in_model_*`` methods return the ops the tests expect to see in
    the compiled graph before/after the fusion pass for this configuration.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        force_kernel: type[FP8ScaledMMLinearKernel] | None,
        group_shape: GroupShape,
        use_aiter_fusion: bool = False,
        use_aiter_quant: bool = False,
        *args,
        **kwargs,
    ):
        """Build the four norms and three FP8 linear layers.

        Args:
            hidden_size: Width of every norm and of the square
                ``(hidden_size, hidden_size)`` FP8 weights.
            eps: RMSNorm epsilon.
            force_kernel: Scaled-mm kernel *class* forced onto the
                per-tensor / per-token layers. Ignored for per-group
                (blockwise) ``group_shape``s.
            group_shape: Activation quantization group shape. Per-group
                shapes build ``TestBlockFP8Layer``s; all others build
                ``TestFP8Layer``s.
            use_aiter_fusion: Expect ROCm aiter fused ops after the pass.
                Also transposes the blockwise layers' weights.
            use_aiter_quant: Use the aiter activation quantization op.
            *args, **kwargs: Forwarded to ``torch.nn.Module``.
        """
        super().__init__(*args, **kwargs)
        self.fp8_linear_layers: list[torch.nn.Module]
        self.group_shape = group_shape
        self.use_aiter_quant_op = use_aiter_quant
        self.use_aiter_fusion = use_aiter_fusion
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        # Whether RMSNorm dispatches to its custom op; this decides which norm
        # ops (custom vs. native torch) are expected in the graph.
        self.enable_rms_norm_custom_op = self.norm[0].enabled()

        # Determine if blockwise based on group_shape
        is_blockwise = group_shape.is_per_group()

        if is_blockwise:
            # Blockwise activation scales are dynamic (is_static=False).
            act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestBlockFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    group_shape=group_shape,
                    cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                    use_aiter_and_is_supported=use_aiter_quant,
                    transpose_weights=use_aiter_fusion,
                )
                for _ in range(3)
            ]
            # With aiter quantization the QuantFP8 custom op is reported as
            # disabled; otherwise ask the layer's input quant op.
            self.enable_quant_fp8_custom_op = (
                False
                if use_aiter_quant
                else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
            )
        else:
            # Per-tensor activation scales are static; any other (per-token)
            # activation scale is computed dynamically.
            is_static = group_shape == GroupShape.PER_TENSOR
            act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
            w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
            self.activation_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
            )
            self.weight_quant_key = QuantKey(
                dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
            )
            self.fp8_linear_layers = [
                TestFP8Layer(
                    weight_shape=(hidden_size, hidden_size),
                    activation_quant_key=self.activation_quant_key,
                    weight_quant_key=self.weight_quant_key,
                    force_kernel=force_kernel,
                )
                for _ in range(3)
            ]

            # Enable aiter quantization if requested
            for layer in self.fp8_linear_layers:
                layer.kernel.quant_fp8.use_aiter = use_aiter_quant

            self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
                0
            ].is_quant_fp8_enabled()

    def forward(self, x):
        """Run the norm / FP8-linear chain described in the class docstring."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear_layers[0](y)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear_layers[1](y2)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear_layers[2](y3)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the activation-quant ops expected before the fusion pass.

        * blockwise + aiter fusion + aiter quant: the aiter group quant op;
        * blockwise + aiter fusion: ``triton_per_token_group_quant_fp8``;
        * non-blockwise + aiter quant: the aiter per-token quant op;
        * otherwise the ``QUANT_OPS`` entry for the activation quant key when
          the QuantFP8 custom op is enabled, else ``aten.reciprocal``
          (presumably emitted by the native quant implementation).
        """
        if self.group_shape.is_per_group():
            # Blockwise path
            if self.use_aiter_fusion and self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_group_quant_op()]
            if self.use_aiter_fusion:
                return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        else:
            if self.use_aiter_quant_op:
                return [rocm_aiter_ops.get_per_token_quant_op()]

        # Common path
        return (
            [QUANT_OPS[self.activation_quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_after(self):
        """Return the fused RMSNorm+quant ops expected after the fusion pass.

        With aiter fusion these are the aiter patterns' ``FUSED_OP``s (group
        or dynamic per-token variants, depending on ``group_shape``).
        Otherwise they are the two ``FUSED_OPS`` entries for the activation
        quant key; the bool in ``FusedRMSQuantKey`` presumably selects the
        fused-residual-add variant.
        """
        if self.use_aiter_fusion:
            if self.group_shape.is_per_group():
                # Blockwise aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSFp8GroupQuantPattern,
                    AiterRMSFp8GroupQuantPattern,
                )

                return [
                    AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                    AiterRMSFp8GroupQuantPattern.FUSED_OP,
                ]
            else:
                # Per-token aiter fusion
                from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                    AiterFusedAddRMSNormDynamicQuantPattern,
                    AiterRMSNormDynamicQuantPattern,
                )

                return [
                    AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                    AiterRMSNormDynamicQuantPattern.FUSED_OP,
                ]

        # Regular fusion
        return [
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return the RMSNorm ops expected before the pass.

        These are the custom ``rms_norm`` / ``fused_add_rms_norm`` ops when
        the RMSNorm custom op is enabled, otherwise ``aten.rsqrt`` from the
        native implementation.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )

255

256
257
258
259
260
261
262
263
264
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Compile ``model`` with and without ``fusion_pass`` and compare them.

    Must be called within vllm_config context. The input tensor is created
    with torch's current default device and dtype, so callers set those
    first.

    Returns:
        ``(fused_backend, unfused_backend)``: the fused backend ran no-op
        elimination, ``fusion_pass`` and cleanup; the unfused one ran only
        no-op elimination and cleanup.
    """
    noop_elim = NoOpEliminationPass(vllm_config)
    post_cleanup = PostCleanupPass(vllm_config)

    fused_backend = TestBackend(noop_elim, fusion_pass, post_cleanup)
    unfused_backend = TestBackend(noop_elim, post_cleanup)

    # Mark the token dimension dynamic so the graph isn't specialized to it.
    inp = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(inp, 0)

    fused_out = torch.compile(model, backend=fused_backend)(inp)
    unfused_out = torch.compile(model, backend=unfused_backend)(inp)

    # float16 gets a tighter tolerance than any other dtype (e.g. bfloat16).
    tol = 2e-3 if dtype == torch.float16 else 1e-2
    torch.testing.assert_close(fused_out, unfused_out, atol=tol, rtol=tol)

    # TestModel contains three norm -> quantized-linear pairs to fuse.
    assert fusion_pass.matched_count == 3
    fused_backend.check_before_ops(model.ops_in_model_before())
    fused_backend.check_after_ops(model.ops_in_model_after())

    return fused_backend, unfused_backend
295
296


297
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    kernel_groupshape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
):
    """Check that RMSNormQuantFusionPass fuses each RMSNorm -> FP8 quant pair.

    Each (kernel, group shape) combination is run with the RMSNorm and
    QuantFP8 custom ops toggled on/off. The test compares fused vs. unfused
    outputs and checks the ops present before/after the pass.
    """
    force_kernel, group_shape = kernel_groupshape

    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    # The test switches torch's process-wide default device/dtype; remember
    # them so they can be restored and do not leak into later tests.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    try:
        with vllm.config.set_current_vllm_config(vllm_config):
            # Setup device before model creation
            torch.set_default_device("cuda")
            torch.set_default_dtype(dtype)
            torch.manual_seed(1)

            fusion_pass = RMSNormQuantFusionPass(vllm_config)

            model = TestModel(
                hidden_size=hidden_size,
                eps=eps,
                force_kernel=force_kernel,
                group_shape=group_shape,
                use_aiter_fusion=False,
                use_aiter_quant=False,
            )

            backend, _ = _run_fusion_test(
                model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
            )
            backend.check_before_ops(
                model.ops_in_model_before_partial(), fully_replaced=False
            )

            # If the RMSNorm custom op is disabled (native/torch impl used),
            # there's a risk that the residual add doesn't get included in
            # the replacement and only the rms part gets fused with quant.
            # Hence we check that exactly 2 add nodes remain: those of the
            # final fused-add RMSNorm, which has no quant to fuse with.
            if not enable_rms_norm_custom_op:

                def n_add_nodes(graph) -> int:
                    return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

                # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
                assert n_add_nodes(backend.graph_pre_pass) == 7
                assert n_add_nodes(backend.graph_post_pass) == 2
    finally:
        # get_default_device() reports cpu when no default was set; pass None
        # then, which clears the default instead of pinning an explicit cpu.
        torch.set_default_device(None if prev_device.type == "cpu" else prev_device)
        torch.set_default_dtype(prev_dtype)
376
377
378
379
380
381
382


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    kernel_groupshape_quant: tuple,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check the ROCm aiter RMSNorm+quant fusion pass on TestModel.

    Aiter is enabled through ``VLLM_ROCM_USE_AITER``. Aiter fused ops are
    always expected, while the aiter activation quant op is toggled by the
    third element of ``kernel_groupshape_quant``.
    """
    force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    # The test switches torch's process-wide default device/dtype; remember
    # them so they can be restored and do not leak into later tests.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    try:
        with (
            vllm.config.set_current_vllm_config(vllm_config),
            monkeypatch.context() as m,
        ):
            from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
                RocmAiterRMSNormQuantFusionPass,
            )

            m.setenv("VLLM_ROCM_USE_AITER", "1")
            rocm_aiter_ops.refresh_env_variables()

            torch.set_default_device("cuda")
            torch.set_default_dtype(dtype)
            torch.manual_seed(1)

            fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

            model = TestModel(
                hidden_size=hidden_size,
                eps=eps,
                force_kernel=force_kernel,
                group_shape=group_shape,
                use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
                use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
            )

            _run_fusion_test(
                model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
            )
    finally:
        # monkeypatch.context() has already restored VLLM_ROCM_USE_AITER here,
        # but rocm_aiter_ops only picks up env changes on
        # refresh_env_variables() (hence the call above). Re-read it so the
        # aiter-enabled state does not leak into later tests.
        rocm_aiter_ops.refresh_env_variables()
        # get_default_device() reports cpu when no default was set; pass None
        # then, which clears the default instead of pinning an explicit cpu.
        torch.set_default_device(None if prev_device.type == "cpu" else prev_device)
        torch.set_default_dtype(prev_dtype)