test_silu_mul_quant_fusion.py 9.22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
4

5
6
7
8
import pytest
import torch

import vllm.envs as envs
9
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
10
from vllm._aiter_ops import IS_AITER_FOUND
11
12
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
13
14
15
16
    FUSED_OPS,
    SILU_MUL_OP,
    ActivationQuantFusionPass,
)
17
from vllm.compilation.fusion import QUANT_OPS
18
from vllm.compilation.noop_elimination import NoOpEliminationPass
19
from vllm.compilation.post_cleanup import PostCleanupPass
20
21
22
23
24
25
26
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
27
from vllm.model_executor.layers.activation import SiluAndMul
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
    PerTensorTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    FP8ScaledMMLinearKernel,
)
43
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
44
from vllm.model_executor.layers.quantization.utils.quant_utils import (
45
46
47
48
    GroupShape,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
49
from vllm.platforms import current_platform
50

51
from ..utils import TestFP8Layer
52
53
from .backend import TestBackend

54
55
# FP8 dtype as reported by the platform the tests run on; used below to cast
# the block-quantized test weight.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype used for FP4 data in these tests.
# NOTE(review): presumably two 4-bit values packed per uint8 -- confirm.
FP4_DTYPE = torch.uint8
56
57


58
59
60
61
62
def is_nvfp4_supported():
    """Report whether the NVFP4 test variant can run on the current platform.

    Returns whatever ``current_platform.has_device_capability`` answers for
    capability ``100``.
    """
    # NOTE(review): 100 presumably encodes compute capability 10.0 -- confirm.
    required_capability = 100
    return current_platform.has_device_capability(required_capability)


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by an FP8 linear layer.

    Both the activation and the weight of the linear layer use the
    ``kFp8StaticTensorSym`` quantization key.
    """

    # Quantization key shared by the activation and the weight.
    quant_key = kFp8StaticTensorSym

    def __init__(
        self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
    ):
        """Build the activation and the square FP8 linear layer.

        Args:
            hidden_size: Side length of the square weight.
            force_kernel: Kernel handed to ``TestFP8Layer`` as ``force_kernel``.
            **kwargs: Accepted and ignored, so every model class in this file
                can be constructed with the same keyword arguments.
        """
        super().__init__()
        self.silu_and_mul = SiluAndMul()

        key = self.quant_key
        square_shape = (hidden_size, hidden_size)
        self.fp8_linear = TestFP8Layer(
            weight_shape=square_shape,
            activation_quant_key=key,
            weight_quant_key=key,
            force_kernel=force_kernel,
        )

        # Record which custom ops are active; this decides which ops are
        # expected in the graph before fusion.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
        self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()

    def forward(self, x):
        """Apply SiluAndMul to ``x`` and feed the result to the FP8 linear layer."""
        return self.fp8_linear(self.silu_and_mul(x))

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass runs."""
        if self.enable_silu_mul_custom_op:
            activation_op = SILU_MUL_OP
        else:
            activation_op = torch.ops.aten.mul

        if self.enable_quant_fp8_custom_op:
            quant_op = QUANT_OPS[kFp8StaticTensorSym]
        else:
            quant_op = torch.ops.aten.reciprocal

        return [activation_op, quant_op]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the fusion pass ran."""
        fused_op = FUSED_OPS[kFp8StaticTensorSym]
        return [fused_op]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """SiluAndMul followed by NVFP4 quantization and a CUTLASS FP4 matmul."""

    def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        """Quantize a random weight and derive the activation global scale.

        Args:
            hidden_size: Side length of the square random weight.
            x: Sample input; it is run through SiluAndMul once to obtain the
                activation global scale ahead of time.
            **kwargs: Accepted and ignored, so every model class in this file
                can be constructed with the same keyword arguments.
        """
        super().__init__()
        from vllm.compilation.activation_quant_fusion import (
            silu_and_mul_nvfp4_quant_supported,
        )

        assert silu_and_mul_nvfp4_quant_supported

        self.silu_and_mul = SiluAndMul()
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

        # create nvfp4 weight
        weight = torch.rand((hidden_size, hidden_size))
        w_quant, w_block_scale, w_global_scale = quant_nvfp4_tensor(weight)
        self.w = w_quant
        self.w_block_scale = w_block_scale
        self.w_global_scale = w_global_scale

        # get global scale offline
        activated = self.silu_and_mul(x)
        _, _, y_global_scale = quant_nvfp4_tensor(activated)
        self.y_global_scale = y_global_scale

        self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)

    def forward(self, x):
        """Run SiluAndMul, quantize the result to FP4 and multiply by the weight."""
        activated = self.silu_and_mul(x)
        quantized, block_scale = scaled_fp4_quant(activated, self.y_global_scale)
        return cutlass_scaled_fp4_mm(
            a=quantized,
            b=self.w,
            block_scale_a=block_scale,
            block_scale_b=self.w_block_scale,
            alpha=self.alpha,
            out_dtype=activated.dtype,
        )

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass runs."""
        if self.enable_silu_mul_custom_op:
            activation_op = SILU_MUL_OP
        else:
            activation_op = torch.ops.aten.mul
        return [activation_op, QUANT_OPS[kNvfp4Quant]]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the fusion pass ran."""
        fused_op = FUSED_OPS[kNvfp4Quant]
        return [fused_op]


144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by a block-quantized W8A8 FP8 linear op."""

    def __init__(self, hidden_size: int, **kwargs):
        """Create the activation, the block FP8 linear op and random weights.

        Args:
            hidden_size: Side length of the square random weight.
            **kwargs: Accepted and ignored, so every model class in this file
                can be constructed with the same keyword arguments.
        """
        super().__init__()
        self.silu_and_mul = SiluAndMul()

        weight_groups = GroupShape(128, 128)
        activation_groups = GroupShape(1, 128)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=weight_groups,
            act_quant_group_shape=activation_groups,
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )

        dense_weight = torch.rand(hidden_size, hidden_size)
        self.w = dense_weight.to(dtype=FP8_DTYPE).t()

        # One scale per 128x128 weight block: ceil(hidden_size / 128) per side.
        scale_hidden_size = (hidden_size + 127) // 128
        scale_shape = (scale_hidden_size, scale_hidden_size)
        self.wscale = torch.rand(scale_shape, dtype=torch.float32)

        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

    def forward(self, x):
        """Apply SiluAndMul, then the block FP8 linear op with the stored weight."""
        activated = self.silu_and_mul(x)
        return self.w8a8_block_fp8_linear.apply(activated, self.w, self.wscale)

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass runs."""
        if self.enable_silu_mul_custom_op:
            activation_op = SILU_MUL_OP
        else:
            activation_op = torch.ops.aten.mul
        return [activation_op]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the fusion pass ran."""
        fused_op = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant
        return [fused_op]


177
178
179
180
181
182
183
184
185
# FP8 scaled-mm kernel classes exercised when the current platform is ROCm.
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
# FP8 scaled-mm kernel classes exercised on every other platform.
CUDA_KERNELS = [
    FlashInferFP8ScaledMMLinearKernel,
    CutlassFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
]
# Kernel list used to parametrize ``force_kernel`` for the FP8 model below;
# chosen once at import time based on the current platform.
TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS


186
187
188
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
    "model_class, enable_quant_fp8_custom_op, force_kernel",
    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
    + [
        (TestSiluMulNvfp4QuantModel, False, None),
        (TestSiluMulGroupFp8QuantModel, False, None),
    ],
)
@pytest.mark.skipif(
    envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    model_class: type[
        TestSiluMulFp8QuantModel
        | TestSiluMulNvfp4QuantModel
        | TestSiluMulGroupFp8QuantModel
    ],
    enable_silu_mul_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
    force_kernel: FP8ScaledMMLinearKernel | None,
):
    """Check that SiluAndMul followed by quantization is fused into one op.

    The model is run eagerly and through ``torch.compile`` with a
    ``TestBackend`` that applies the fusion passes. The test then checks that
    both runs agree numerically, that exactly one pattern was matched across
    the fusion passes, and that the expected ops are present before and after
    the passes.

    Args:
        num_tokens: Number of rows of the random input.
        hidden_size: Model hidden size; the input has ``2 * hidden_size``
            columns.
        dtype: Default torch dtype set for the test.
        model_class: Model variant under test.
        enable_silu_mul_custom_op: Whether ``+silu_and_mul`` is added to the
            compilation config's custom ops.
        enable_quant_fp8_custom_op: Whether ``+quant_fp8`` is added to the
            compilation config's custom ops.
        force_kernel: Kernel forwarded to the model constructor (``None`` for
            the variants that ignore it).

    Raises:
        ValueError: If ``model_class`` has no comparison tolerance configured.
    """
    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
        pytest.skip("NVFP4 is not supported on this GPU.")
    if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
        pytest.skip("AITER is not supported on this GPU.")

    # Resolve the comparison tolerances up front so that a model class without
    # an entry fails with a clear message instead of an UnboundLocalError at
    # the comparison further down.
    if model_class is TestSiluMulFp8QuantModel:
        atol, rtol = 1e-3, 1e-3
    elif model_class is TestSiluMulNvfp4QuantModel:
        atol, rtol = 1e-1, 1e-1
    elif model_class is TestSiluMulGroupFp8QuantModel:
        atol, rtol = 5e-2, 5e-2
    else:
        raise ValueError(
            f"No comparison tolerance configured for {model_class.__name__}"
        )

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)

    x = torch.rand(num_tokens, hidden_size * 2)

    custom_ops = []
    if enable_silu_mul_custom_op:
        custom_ops.append("+silu_and_mul")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
        ),
    )

    with set_current_vllm_config(config):
        fusion_passes = [ActivationQuantFusionPass(config)]
        if IS_AITER_FOUND:
            # Imported lazily: only needed when AITER is available.
            from vllm.compilation.rocm_aiter_fusion import (
                RocmAiterSiluMulFp8GroupQuantFusionPass,
            )

            fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

        # Reshape pass is needed for the fusion pass to work
        # NOTE(review): presumably this refers to NoOpEliminationPass, which
        # runs ahead of the fusion passes here -- confirm.
        passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
        backend = TestBackend(*passes)
        model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)

        # First dimension dynamic
        torch._dynamo.mark_dynamic(x, 0)

        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Check that it gives the same answer
        # NOTE(review): only index 0 of each result is compared -- confirm
        # that this is intended.
        torch.testing.assert_close(
            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
        )

        # Exactly one pattern must have been matched across all fusion passes.
        assert sum(p.matched_count for p in fusion_passes) == 1

        # In pre-nodes, quant op should be present and fused kernels should not
        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be present and quant op should not
        backend.check_after_ops(model.ops_in_model_after())