test_fusion.py 8.75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pytest
import torch

7
import vllm.plugins
8
9
10
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
11
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
from vllm.compilation.post_cleanup import PostCleanupPass
13
14
15
16
17
18
19
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
24
from vllm.model_executor.layers.quantization.utils.quant_utils import (
25
26
27
28
    GroupShape,
    QuantKey,
    ScaleDesc,
)
29
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
30
    Fp8LinearOp,
31
    cutlass_block_fp8_supported,
32
33
34
    cutlass_fp8_supported,
    maybe_create_device_identity,
)
35
from vllm.platforms import current_platform
36
from vllm.utils.deep_gemm import is_deep_gemm_supported
37

38
from ..utils import override_cutlass_fp8_supported
39
40
from .backend import TestBackend

41
42
# FP8 dtype reported by the current platform; used below for the test weights
# and for the quantization key.
FP8_DTYPE = current_platform.fp8_dtype()

43
44
45
# Default overloads of the `_C.rms_norm` and `_C.fused_add_rms_norm` custom ops.
# The test looks for these in the pre-pass graph when the RMSNorm custom op is
# enabled.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

46
47

class TestModel(torch.nn.Module):
    """Toy network alternating RMSNorm layers with FP8 linear layers.

    The forward pass is ``relu -> norm -> linear -> norm(+resid) -> linear ->
    norm(+resid) -> linear -> norm(+resid)``, so three norm outputs feed an
    FP8 linear op and a fourth norm produces the returned tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        group_shape: GroupShape,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        """Build the norm layers, random FP8 weights/scales and the linear op.

        Args:
            hidden_size: Size of both dimensions of every weight matrix and
                the hidden size passed to each ``RMSNorm``.
            eps: Epsilon passed to each ``RMSNorm``.
            group_shape: Activation quantization group shape. Per-group shapes
                select ``W8A8BlockFp8LinearOp``; everything else selects
                ``Fp8LinearOp``.
            cuda_force_torch: When true, ``Fp8LinearOp`` is constructed with
                cutlass FP8 support overridden to ``False``.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]

        per_group = group_shape.is_per_group()

        # Weight scales: one entry per square weight block for per-group
        # quantization, a single scalar otherwise.
        if per_group:
            n_blocks = hidden_size // group_shape[1]
            self.wscale = [
                torch.rand((n_blocks, n_blocks), dtype=torch.float32)
                for _ in range(3)
            ]
        else:
            self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

        # Only per-tensor quantization uses a static (precomputed) input scale.
        static = group_shape == GroupShape.PER_TENSOR
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        else:
            self.scale = [None for _ in range(3)]

        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
        ]
        if not per_group:
            # Transpose each weight individually. Previously only the first
            # weight was transposed and reused for all three layers, which
            # silently discarded the other two random weights.
            self.w = [w.t() for w in self.w]

        if per_group:
            self.fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
                act_quant_group_shape=group_shape,
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=False,
            )
            self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
        else:
            with override_cutlass_fp8_supported(not cuda_force_torch):
                self.fp8_linear = Fp8LinearOp(
                    act_quant_static=static,
                    act_quant_group_shape=group_shape,
                )
                self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

        self.enable_rms_norm_custom_op = self.norm[0].enabled()
        self.group_shape = group_shape

    def forward(self, x):
        """Run the norm/linear chain on ``x`` and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Return the fused ops looked up for this quant key.

        One entry is keyed with ``True`` and one with ``False`` as the second
        ``FusedRMSQuantKey`` field.
        """
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
        ]

    def ops_in_model_before(self):
        """Return the quant op(s) the test checks for with ``check_before_ops``.

        This is the registered quant custom op when it is enabled, otherwise
        ``aten.reciprocal``.
        """
        return (
            [QUANT_OPS[self.quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_before_partial(self):
        """Return the RMSNorm op(s) the test checks with ``fully_replaced=False``.

        These are the RMSNorm custom ops when enabled, otherwise ``aten.rsqrt``.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )

146

147
148
149
150
151
152
153
154
# Activation-quantization group shapes the test is parametrized over:
# per-token, per-tensor, and two explicit (1, N) shapes.
GROUP_SHAPES = [
    GroupShape.PER_TOKEN,
    GroupShape.PER_TENSOR,
    GroupShape(1, 128),
    GroupShape(1, 64),
]


155
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    group_shape,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
    cuda_force_torch,
):
    """Check that RMSNorm + FP8 quant fusion matches and preserves results.

    The same model is compiled twice, once with the fusion pass and once
    without. The outputs are compared, then the fused backend's graphs are
    inspected for the expected ops and match count.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    if group_shape.is_per_group() and not enable_quant_fp8_custom_op:
        pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")

    # Group shape (1, 64) is not run when cutlass block fp8 or deepgemm is
    # available.
    if group_shape == GroupShape(1, 64):
        if cutlass_block_fp8_supported() or is_deep_gemm_supported():
            pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")

    # Enable only the custom ops requested by the parametrization.
    op_toggles = (
        (enable_rms_norm_custom_op, "+rms_norm"),
        (enable_quant_fp8_custom_op, "+quant_fp8"),
    )
    custom_ops = [flag for wanted, flag in op_toggles if wanted]

    model_config = ModelConfig(dtype=dtype)
    pass_config = PassConfig(
        fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
    )
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=custom_ops,
        pass_config=pass_config,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # Reshape pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        fused_backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        unfused_backend = TestBackend(noop_pass, cleanup_pass)
        model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        result_fused = torch.compile(model, backend=fused_backend)(x)
        result_unfused = torch.compile(model, backend=unfused_backend)(x)

        # float16 gets a tighter tolerance than the other dtype.
        atol = rtol = 2e-3 if dtype == torch.float16 else 1e-2
        torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

        assert fusion_pass.matched_count == 3
        fused_backend.check_before_ops(model.ops_in_model_before())
        fused_backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )
        fused_backend.check_after_ops(model.ops_in_model_after())

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def count_add_nodes(graph):
                """Count the ``aten.add`` nodes found in ``graph``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert count_add_nodes(fused_backend.graph_pre_pass) == 7
            assert count_add_nodes(fused_backend.graph_post_pass) == 2