test_fusion.py 11.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import itertools

6
7
8
import pytest
import torch

9
import vllm.plugins
10
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
11
12
13
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
14
from vllm.compilation.noop_elimination import NoOpEliminationPass
15
from vllm.compilation.post_cleanup import PostCleanupPass
16
17
18
19
20
21
22
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
28
29
30
31
    GroupShape,
    QuantKey,
    ScaleDesc,
)
32
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
33
    Fp8LinearOp,
34
    cutlass_block_fp8_supported,
35
36
37
    cutlass_fp8_supported,
    maybe_create_device_identity,
)
38
from vllm.platforms import current_platform
39
from vllm.utils.deep_gemm import is_deep_gemm_supported
40

41
from ..utils import override_cutlass_fp8_supported
42
43
from .backend import TestBackend

44
45
# FP8 dtype reported by the current platform; used below for the test weights
# and for the quantization key of the models under test.
FP8_DTYPE = current_platform.fp8_dtype()

46
47
48
# Custom-op overloads for RMSNorm and fused-add RMSNorm. TestModel reports
# these as ops present before fusion when the RMSNorm custom op is enabled.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

49
50

class TestModel(torch.nn.Module):
    """Test network chaining four RMSNorm layers with three FP8 linear ops.

    The layout is ``relu -> norm -> (fp8 linear -> norm-with-residual) x 3``.

    Args:
        hidden_size: Width of the norms and of the square linear weights.
        eps: Epsilon passed to every ``RMSNorm``.
        group_shape: Activation quantization group shape. Per-group shapes
            select ``W8A8BlockFp8LinearOp``; every other shape selects
            ``Fp8LinearOp``. ``GroupShape.PER_TENSOR`` is treated as static
            quantization and gets explicit input scales.
        cuda_force_torch: ``Fp8LinearOp`` is built while the cutlass-fp8
            support flag is overridden with ``not cuda_force_torch``.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        group_shape: GroupShape,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]

        per_group = group_shape.is_per_group()

        # Weight scales: a square grid of block scales for per-group
        # quantization, a single scalar otherwise.
        if per_group:
            blocks = hidden_size // group_shape[1]
            wscale_shape = (blocks, blocks)
        else:
            wscale_shape = (1,)
        self.wscale = [
            torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)
        ]

        static = group_shape == GroupShape.PER_TENSOR
        self.quant_key = QuantKey(
            dtype=FP8_DTYPE,
            scale=ScaleDesc(torch.float32, static, group_shape),
            symmetric=True,
        )

        # Only static quantization carries precomputed activation scales.
        self.scale = [
            torch.rand(1, dtype=torch.float32) if static else None
            for _ in range(3)
        ]

        weights = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE)
            for _ in range(3)
        ]
        if not per_group:
            # NOTE(review): all three entries are transposed views of the
            # first weight; the other two random weights are discarded.
            # Confirm this aliasing is intended.
            first = weights[0]
            weights = [first.t() for _ in range(3)]
        self.w = weights

        if per_group:
            block = group_shape[1]
            self.fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(block, block),
                act_quant_group_shape=group_shape,
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=False,
            )
            self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
        else:
            with override_cutlass_fp8_supported(not cuda_force_torch):
                self.fp8_linear = Fp8LinearOp(
                    act_quant_static=static,
                    act_quant_group_shape=group_shape,
                )
                self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

        self.enable_rms_norm_custom_op = self.norm[0].enabled()
        self.group_shape = group_shape

    def forward(self, x):
        """Run the norm / FP8-linear chain and return the last norm output."""
        # The relu keeps the graph input from being a direct pattern argument.
        residual = torch.relu(x)
        normed = self.norm[0](residual)

        # Each stage threads the residual through the following norm so the
        # fused-add replacement has a residual to work with.
        stages = zip(self.norm[1:], self.w, self.wscale, self.scale)
        for norm, weight, weight_scale, input_scale in stages:
            projected = self.fp8_linear.apply(
                normed, weight, weight_scale, input_scale=input_scale
            )
            normed, residual = norm(projected, residual)

        return normed

    def ops_in_model_after(self):
        """Return the fused ops looked up for this quant key, with and without add."""
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, fused_add)]
            for fused_add in (True, False)
        ]

    def ops_in_model_before(self):
        """Return the quant op expected before fusion."""
        if self.enable_quant_fp8_custom_op:
            return [QUANT_OPS[self.quant_key]]
        return [torch.ops.aten.reciprocal]

    def ops_in_model_before_partial(self):
        """Return the RMSNorm ops expected before fusion."""
        if self.enable_rms_norm_custom_op:
            return [RMS_OP, RMS_ADD_OP]
        return [torch.ops.aten.rsqrt]

149

150
151
152
153
154
155
156
157
# Activation quantization group shapes that test_fusion_rmsnorm_quant is
# parametrized over.
GROUP_SHAPES = [
    GroupShape.PER_TOKEN,
    GroupShape.PER_TENSOR,
    GroupShape(1, 128),
    GroupShape(1, 64),
]


158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
    """Test network built on the ROCm AITER RMSNorm ops and block FP8 linear.

    The layout is ``relu -> rms_norm -> (block fp8 linear ->
    rms_norm2d_with_add) x 3``.

    Args:
        hidden_size: Width of the norm weights and the square linear weights.
        eps: Epsilon passed to every AITER RMSNorm call.
        **kwargs: Accepted and ignored, so the class can be constructed with
            the same keyword arguments as ``TestModel``.
    """

    def __init__(self, hidden_size: int, eps: float, **kwargs):
        super().__init__()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )

        weights = []
        for _ in range(3):
            dense = torch.rand(hidden_size, hidden_size)
            weights.append(dense.to(dtype=FP8_DTYPE).t())
        self.w = weights

        # One scale per 128x128 weight block, rounding the block count up.
        n_blocks = (hidden_size + 127) // 128
        self.wscale = [
            torch.rand((n_blocks, n_blocks), dtype=torch.float32)
            for _ in range(3)
        ]

        self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
        self.eps = eps

    def forward(self, x):
        """Run the norm / FP8-linear chain and return the last norm output."""
        # The relu keeps the graph input from being a direct pattern argument.
        residual = torch.relu(x)
        normed = rocm_aiter_ops.rms_norm(residual, self.norm_weight[0], self.eps)

        # The residual is threaded through every following norm so the
        # with-add replacement has a residual to work with.
        stages = zip(self.w, self.wscale, self.norm_weight[1:])
        for weight, weight_scale, norm_weight in stages:
            projected = self.w8a8_block_fp8_linear.apply(normed, weight, weight_scale)
            normed, residual = rocm_aiter_ops.rms_norm2d_with_add(
                projected, residual, norm_weight, self.eps
            )

        return normed

    def ops_in_model_before(self):
        """Return the AITER ops expected before fusion."""
        aiter_ops = torch.ops.vllm
        return [
            aiter_ops.rocm_aiter_rms_norm,
            aiter_ops.rocm_aiter_group_fp8_quant,
        ]

    def ops_in_model_before_partial(self):
        """Return an empty list: no partially-replaced ops are checked."""
        return []

    def ops_in_model_after(self):
        """Return the fused AITER ops expected after fusion."""
        aiter_ops = torch.ops.vllm
        return [
            aiter_ops.rocm_aiter_rmsnorm_fp8_group_quant,
            aiter_ops.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
        ]


221
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize(
    "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
    [
        *itertools.product([TestModel], [True, False], [True, False]),
        (TestRmsnormGroupFp8QuantModel, False, False),
    ],
)
# cuda_force_torch exercises the torch code path on platforms where
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    group_shape,
    model_class,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
    cuda_force_torch,
):
    """Check that RMSNorm + FP8 quant fusion matches and keeps the numerics.

    The same model is compiled twice, once with the fusion pass and once
    without it. The outputs must agree within a dtype-dependent tolerance,
    the fusion pass must report three matches, and the graphs before and
    after the pass must contain the ops the model says to expect.
    """
    is_aiter_model = model_class is TestRmsnormGroupFp8QuantModel
    if is_aiter_model and not IS_AITER_FOUND:
        pytest.skip("AITER is not supported on this GPU.")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    if group_shape.is_per_group() and not enable_quant_fp8_custom_op:
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    # Group size 64 is skipped when running with cutlass or deepgemm.
    if group_shape == GroupShape(1, 64):
        if cutlass_block_fp8_supported() or is_deep_gemm_supported():
            pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    requested_ops = (
        ("+rms_norm", enable_rms_norm_custom_op),
        ("+quant_fp8", enable_quant_fp8_custom_op),
    )
    custom_ops = [flag for flag, wanted in requested_ops if wanted]

    model_config = ModelConfig(dtype=dtype)
    pass_config = PassConfig(
        fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
    )
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=custom_ops,
        pass_config=pass_config,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # The fusion pass needs the no-op elimination (reshape) pass to work.
        noop_pass = NoOpEliminationPass(vllm_config)

        if is_aiter_model:
            from vllm.compilation.rocm_aiter_fusion import (
                RocmAiterRMSNormFp8GroupQuantFusionPass,
            )

            fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
        else:
            fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        backend2 = TestBackend(noop_pass, cleanup_pass)

        model = model_class(
            hidden_size=hidden_size,
            eps=eps,
            group_shape=group_shape,
            cuda_force_torch=cuda_force_torch,
        )

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        result_fused = torch.compile(model, backend=backend)(x)
        result_unfused = torch.compile(model, backend=backend2)(x)

        # float16 gets the tighter tolerance of the two dtypes.
        tolerance = 2e-3 if dtype == torch.float16 else 1e-2
        torch.testing.assert_close(
            result_fused, result_unfused, atol=tolerance, rtol=tolerance
        )

        assert fusion_pass.matched_count == 3
        backend.check_before_ops(model.ops_in_model_before())
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )
        backend.check_after_ops(model.ops_in_model_after())

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not (enable_rms_norm_custom_op or is_aiter_model):

            def count_add_nodes(graph):
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert count_add_nodes(backend.graph_pre_pass) == 7
            assert count_add_nodes(backend.graph_post_pass) == 2