# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported

from ..utils import override_cutlass_fp8_supported
from .backend import TestBackend

43
44
# FP8 dtype reported by the current platform; used for the test weights and
# for the activation quantization key.
FP8_DTYPE = current_platform.fp8_dtype()

# Unfused RMSNorm custom ops (plain and fused-add variants) that the test
# expects to find in the graph when the rms_norm custom op is enabled.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

48
49

class TestModel(torch.nn.Module):
    """Toy model for the RMSNorm + FP8-quant fusion tests.

    ``forward`` chains four RMSNorm layers with three FP8 linear ops between
    them, so each of the first three norms feeds a quantized linear layer.
    The ``ops_in_model_*`` helpers report which ops the tests expect to see
    in the graph before and after the fusion pass.
    """

    # The class name starts with "Test" but this is a helper, not a test
    # class; tell pytest not to try collecting it.
    __test__ = False

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        group_shape: GroupShape,
        use_aiter: bool = False,
        cuda_force_torch: bool = False,
        use_aiter_quant_op: bool = True,
        *args,
        **kwargs,
    ):
        """Build norms, FP8 weights, scales and the FP8 linear op.

        Args:
            hidden_size: Width of the square weight matrices and norm layers.
            eps: Epsilon passed to every ``RMSNorm``.
            group_shape: Activation quantization group shape.
            use_aiter: Select the AITER variants of the linear/quant ops.
            cuda_force_torch: When not using AITER and not per-group, build
                ``Fp8LinearOp`` with cutlass FP8 support overridden to
                ``not cuda_force_torch``.
            use_aiter_quant_op: Whether the AITER quant op is used when
                ``use_aiter`` is set.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.use_aiter = use_aiter
        self.use_aiter_quant_op = use_aiter_quant_op
        self.cuda_force_torch = cuda_force_torch
        self.group_shape = group_shape
        self.enable_quant_fp8_custom_op = None  # Will be set later if applicable

        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]

        # Setup quantization scale descriptor
        static = group_shape == GroupShape.PER_TENSOR and not use_aiter
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)

        # Setup scales (only static quantization carries an input scale)
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        else:
            self.scale = [None for _ in range(3)]

        # Setup weights
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
        ]
        if not group_shape.is_per_group() or use_aiter:
            # NOTE(review): all three entries become the transpose of the
            # first weight; kept as in the original.
            self.w = [self.w[0].t() for _ in range(3)]

        # Setup weight scales
        if group_shape.is_per_group():
            scale_size = (
                (hidden_size + 128 - 1) // 128
                if use_aiter
                else hidden_size // group_shape[1]
            )
            wscale_shape: tuple[int, ...] = (scale_size, scale_size)
        else:
            wscale_shape = (1,)
        self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]

        # Setup FP8 linear operation
        is_per_group = group_shape.is_per_group()
        if is_per_group and use_aiter:
            self.fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(128, 128),
                act_quant_group_shape=group_shape,
                use_aiter_and_is_supported=use_aiter_quant_op,
            )
            # AITER blockwise doesn't use enable_quant_fp8_custom_op
        elif is_per_group:
            self.fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
                act_quant_group_shape=group_shape,
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=False,
            )
            self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
        elif use_aiter:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=False,
                act_quant_group_shape=group_shape,
            )
            self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
            self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
        else:
            with override_cutlass_fp8_supported(not cuda_force_torch):
                self.fp8_linear = Fp8LinearOp(
                    act_quant_static=static,
                    act_quant_group_shape=group_shape,
                )
                self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

        self.enable_rms_norm_custom_op = self.norm[0].enabled()

    def forward(self, x):
        """Run relu -> (norm -> fp8 linear) x 3 -> norm and return the output."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the quant ops expected in the graph before fusion."""
        if (
            self.use_aiter
            and self.group_shape.is_per_group()
            and current_platform.is_fp8_fnuz()
        ):
            return [rocm_aiter_ops.get_group_quant_op()]
        if self.use_aiter and self.group_shape.is_per_group():
            return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
        if self.use_aiter and self.use_aiter_quant_op:
            return [rocm_aiter_ops.get_per_token_quant_op()]
        # AITER without its quant op and the enabled custom quant op both
        # map to the registered quant op for this quant key.
        if self.use_aiter or self.enable_quant_fp8_custom_op:
            return [QUANT_OPS[self.quant_key]]
        # Fallback when none of the custom quant ops above applies.
        return [torch.ops.aten.reciprocal]

    def ops_in_model_after(self):
        """Return the fused ops expected in the graph after fusion."""
        if self.use_aiter and self.group_shape.is_per_group():
            # Imported lazily, as in the other AITER-only code paths.
            from vllm.compilation.rocm_aiter_fusion import (
                AiterFusedAddRMSFp8GroupQuantPattern,
                AiterRMSFp8GroupQuantPattern,
            )

            return [
                AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
                AiterRMSFp8GroupQuantPattern.FUSED_OP,
            ]
        if self.use_aiter:
            from vllm.compilation.rocm_aiter_fusion import (
                AiterFusedAddRMSNormDynamicQuantPattern,
                AiterRMSNormDynamicQuantPattern,
            )

            return [
                AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
                AiterRMSNormDynamicQuantPattern.FUSED_OP,
            ]
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
        ]

    def ops_in_model_before_partial(self):
        """Return the RMSNorm ops expected before fusion.

        The caller checks these with ``fully_replaced=False``.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )

208

209
210
211
212
213
214
215
216
# Activation-quantization group shapes used to parametrize
# test_fusion_rmsnorm_quant.
GROUP_SHAPES = [
    GroupShape.PER_TOKEN,
    GroupShape.PER_TENSOR,
    GroupShape(1, 128),
    GroupShape(1, 64),
]


217
218
219
220
221
222
223
224
225
def _run_fusion_test(
    model,
    fusion_pass,
    vllm_config,
    dtype,
    hidden_size,
    num_tokens,
):
    """Helper function for common fusion test logic.

    Compiles ``model`` twice, once with ``fusion_pass`` in the pass pipeline
    and once without, checks that both produce close outputs, and verifies
    that the pass matched three times and that the expected ops appear
    before/after it.

    Must be called within vllm_config context.

    Args:
        model: Model exposing ``ops_in_model_before`` / ``ops_in_model_after``.
        fusion_pass: The fusion pass under test; must expose ``matched_count``.
        vllm_config: Config used to build the no-op elimination and cleanup
            passes.
        dtype: Model dtype; selects the comparison tolerances.
        hidden_size: Second dimension of the random input.
        num_tokens: First (dynamic) dimension of the random input.

    Returns:
        ``(backend, backend2)``: the fused and the unfused test backends.
    """
    noop_pass = NoOpEliminationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    # Same pipeline with and without the fusion pass.
    backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
    backend2 = TestBackend(noop_pass, cleanup_pass)

    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    model_unfused = torch.compile(model, backend=backend2)
    result_unfused = model_unfused(x)

    # float16 is compared with a tighter tolerance than the other dtypes.
    if dtype == torch.float16:
        atol, rtol = (2e-3, 2e-3)
    else:
        atol, rtol = (1e-2, 1e-2)

    torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

    # One match per (norm -> fp8 linear) pair in the model.
    assert fusion_pass.matched_count == 3
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())

    return backend, backend2
256
257


258
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    group_shape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
    cuda_force_torch,
):
    """Check that RMSNormQuantFusionPass fuses RMSNorm with FP8 quant.

    Runs the non-AITER ``TestModel`` through ``_run_fusion_test`` for every
    combination of dtype, eps, group shape and custom-op enablement, then
    checks the RMSNorm ops and, when the RMSNorm custom op is disabled, the
    number of ``aten.add`` nodes before and after the pass.
    """
    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    if group_shape == GroupShape(1, 64) and (
        cutlass_block_fp8_supported() or is_deep_gemm_supported()
    ):
        pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # Setup device before model creation
        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)
        maybe_create_device_identity()

        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        model = TestModel(
            hidden_size=hidden_size,
            eps=eps,
            group_shape=group_shape,
            use_aiter=False,
            cuda_force_torch=cuda_force_torch,
        )

        backend, _ = _run_fusion_test(
            model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
        )
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(graph) -> int:
                """Count the aten.add nodes in ``graph``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 7
            assert n_add_nodes(backend.graph_post_pass) == 2
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400


# (group_shape, use_aiter_quant_op) pairs used to parametrize
# test_aiter_fusion_rmsnorm_quant.
GROUP_SHAPE_QUANT_OPS_MATCHS = [
    (GroupShape.PER_TOKEN, True),
    (GroupShape.PER_TOKEN, False),
    (GroupShape(1, 128), True),
]


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
    "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
)
@pytest.mark.skipif(
    (not current_platform.is_rocm() or not IS_AITER_FOUND),
    reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    group_shape: GroupShape,
    use_aiter_quant_op: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that RocmAiterRMSNormFusionPass fuses RMSNorm with FP8 quant.

    Enables AITER through ``VLLM_ROCM_USE_AITER`` for the duration of the
    test, builds the AITER variant of ``TestModel`` and runs it through
    ``_run_fusion_test``.
    """
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # Imported lazily: this module is only needed on ROCm with AITER.
        from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass

        try:
            with monkeypatch.context() as m:
                m.setenv("VLLM_ROCM_USE_AITER", "1")
                # Pick up the patched environment variable.
                rocm_aiter_ops.refresh_env_variables()

                torch.set_default_device("cuda")
                torch.set_default_dtype(dtype)
                torch.manual_seed(1)
                maybe_create_device_identity()

                fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
                model = TestModel(
                    hidden_size=hidden_size,
                    eps=eps,
                    group_shape=group_shape,
                    use_aiter=True,
                    use_aiter_quant_op=use_aiter_quant_op,
                )

                _run_fusion_test(
                    model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
                )
        finally:
            # monkeypatch has restored the environment by now; refresh again
            # so the AITER-enabled state set above does not leak into later
            # tests. NOTE(review): assumes refresh_env_variables re-reads the
            # current environment - confirm against rocm_aiter_ops.
            rocm_aiter_ops.refresh_env_variables()