test_silu_mul_quant_fusion.py 8.77 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
4

5
6
7
8
import pytest
import torch

import vllm.envs as envs
9
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
10
from vllm._aiter_ops import IS_AITER_FOUND
11
12
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
13
14
15
16
    FUSED_OPS,
    SILU_MUL_OP,
    ActivationQuantFusionPass,
)
17
from vllm.compilation.fusion import QUANT_OPS
18
from vllm.compilation.noop_elimination import NoOpEliminationPass
19
from vllm.compilation.post_cleanup import PostCleanupPass
20
21
22
23
24
25
26
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
27
from vllm.model_executor.layers.activation import SiluAndMul
28
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
29
from vllm.model_executor.layers.quantization.utils.quant_utils import (
30
31
32
33
    GroupShape,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
34
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
35
    Fp8LinearOp,
36
    maybe_create_device_identity,
37
)
38
from vllm.platforms import current_platform
39

40
from ..utils import override_cutlass_fp8_supported
41
42
from .backend import TestBackend

43
44
# FP8 dtype reported by the platform the tests are running on.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): presumably NVFP4 values are stored packed in uint8 - confirm.
FP4_DTYPE = torch.uint8
45
46


47
48
49
50
51
def is_nvfp4_supported():
    """Report whether the current platform meets the NVFP4 capability bar.

    Delegates to ``current_platform.has_device_capability``.
    """
    # NOTE(review): 100 presumably encodes compute capability 10.0 - confirm.
    required_capability = 100
    return current_platform.has_device_capability(required_capability)


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by a statically quantized, per-tensor FP8 linear op.

    The model exposes the ops expected in the graph before and after the
    activation+quant fusion pass so the test can verify the rewrite.

    Args:
        hidden_size: Side length of the square weight matrix.
        cuda_force_torch: When True, CUTLASS FP8 support is overridden to
            False while the linear op is built, so the torch code path is
            exercised.
        **kwargs: Ignored; lets every test model share one constructor call.
    """

    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # wscale: weight scale, scale: static activation (input) scale.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

        # Record which custom ops are active; this decides which ops are
        # expected in the pre-fusion graph.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

    def forward(self, x):
        """Apply SiluAndMul, then the FP8 linear op with static input scale."""
        y = self.silu_and_mul(x)
        # Quantize activations with their own scale; self.wscale belongs to
        # the weights only.
        return self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass."""
        return [
            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
            (
                QUANT_OPS[kFp8StaticTensorSym]
                if self.enable_quant_fp8_custom_op
                else torch.ops.aten.reciprocal
            ),
        ]

    def ops_in_model_after(self):
        """Return the fused op expected in the graph after the fusion pass."""
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """SiluAndMul followed by NVFP4 quantization and a CUTLASS FP4 matmul.

    Args:
        hidden_size: Side length of the square weight matrix before
            quantization.
        x: Sample input; its activation output is quantized once at
            construction time to obtain the activation global scale.
        **kwargs: Ignored; lets every test model share one constructor call.
    """

    def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        super().__init__()
        from vllm.compilation.activation_quant_fusion import (
            silu_and_mul_nvfp4_quant_supported,
        )

        assert silu_and_mul_nvfp4_quant_supported

        self.silu_and_mul = SiluAndMul()
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

        # Build the NVFP4 weight from a random dense matrix.
        dense_weight = torch.rand((hidden_size, hidden_size))
        quantized_weight = quant_nvfp4_tensor(dense_weight)
        self.w, self.w_block_scale, self.w_global_scale = quantized_weight

        # Derive the activation global scale offline from the sample input.
        activated = self.silu_and_mul(x)
        _, _, activation_global_scale = quant_nvfp4_tensor(activated)
        self.y_global_scale = activation_global_scale

        combined_scale = self.w_global_scale * self.y_global_scale
        self.alpha = 1.0 / combined_scale

    def forward(self, x):
        """Apply SiluAndMul, quantize to FP4 and run the scaled FP4 matmul."""
        activated = self.silu_and_mul(x)
        quantized, block_scale = scaled_fp4_quant(activated, self.y_global_scale)
        return cutlass_scaled_fp4_mm(
            a=quantized,
            b=self.w,
            block_scale_a=block_scale,
            block_scale_b=self.w_block_scale,
            alpha=self.alpha,
            out_dtype=activated.dtype,
        )

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass."""
        if self.enable_silu_mul_custom_op:
            activation_op = SILU_MUL_OP
        else:
            activation_op = torch.ops.aten.mul
        return [activation_op, QUANT_OPS[kNvfp4Quant]]

    def ops_in_model_after(self):
        """Return the fused op expected in the graph after the fusion pass."""
        fused_op = FUSED_OPS[kNvfp4Quant]
        return [fused_op]


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by a block-wise (group) FP8 linear op using AITER.

    Args:
        hidden_size: Side length of the square weight matrix.
        **kwargs: Ignored; lets every test model share one constructor call.
    """

    def __init__(self, hidden_size: int, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()

        block = 128
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )

        dense_weight = torch.rand(hidden_size, hidden_size)
        self.w = dense_weight.to(dtype=FP8_DTYPE).t()

        # One weight scale per 128x128 block; round the block count up.
        blocks_per_dim = (hidden_size + block - 1) // block
        scale_shape = (blocks_per_dim, blocks_per_dim)
        self.wscale = torch.rand(scale_shape, dtype=torch.float32)

        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

    def forward(self, x):
        """Apply SiluAndMul, then the block-quantized FP8 linear op."""
        activated = self.silu_and_mul(x)
        return self.w8a8_block_fp8_linear.apply(activated, self.w, self.wscale)

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass."""
        if self.enable_silu_mul_custom_op:
            activation_op = SILU_MUL_OP
        else:
            activation_op = torch.ops.aten.mul
        return [activation_op]

    def ops_in_model_after(self):
        """Return the fused op expected in the graph after the fusion pass."""
        fused_op = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant
        return [fused_op]


164
165
166
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
    + [
        (TestSiluMulNvfp4QuantModel, False, False),
        (TestSiluMulGroupFp8QuantModel, False, False),
    ],
)
# cuda_force_torch is used to test the torch code path on platforms where
# cutlass_fp8_supported() == True.
@pytest.mark.skipif(
    envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    model_class: type[
        TestSiluMulFp8QuantModel
        | TestSiluMulNvfp4QuantModel
        | TestSiluMulGroupFp8QuantModel
    ],
    enable_silu_mul_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
    cuda_force_torch: bool,
):
    """Check that SiluAndMul + quant is fused and still matches eager output.

    The model is run eagerly and through ``torch.compile`` with a backend
    that applies the fusion passes. The test then asserts that the outputs
    agree within a per-model tolerance, that exactly one fusion match
    happened, and that the expected ops appear before and after the passes.

    Args:
        num_tokens: Number of rows of the random input.
        hidden_size: Model hidden size; the input has ``2 * hidden_size``
            columns.
        dtype: Default dtype used for the input and the comparison.
        model_class: Test model to instantiate.
        enable_silu_mul_custom_op: Adds ``+silu_and_mul`` to the custom ops.
        enable_quant_fp8_custom_op: Adds ``+quant_fp8`` to the custom ops.
        cuda_force_torch: Forwarded to the model constructor.
    """
    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
        pytest.skip("NVFP4 is not supported on this GPU.")
    if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
        pytest.skip("AITER is not supported on this GPU.")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    maybe_create_device_identity()

    x = torch.rand(num_tokens, hidden_size * 2)

    # The no-op elimination pass is needed for the fusion pass to work
    custom_ops = []
    if enable_silu_mul_custom_op:
        custom_ops.append("+silu_and_mul")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
        ),
    )

    with set_current_vllm_config(config):
        fusion_passes = [ActivationQuantFusionPass(config)]
        if IS_AITER_FOUND:
            # Imported lazily: only needed (and only used) when AITER exists.
            from vllm.compilation.rocm_aiter_fusion import (
                RocmAiterSiluMulFp8GroupQuantFusionPass,
            )

            fusion_passes.append(RocmAiterSiluMulFp8GroupQuantFusionPass(config))

        passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
        backend = TestBackend(*passes)
        model = model_class(
            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
        )

        # First dimension dynamic
        torch._dynamo.mark_dynamic(x, 0)

        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Check that it gives the same answer. Tolerances are looked up per
        # model class so an unhandled class fails loudly with a KeyError
        # instead of leaving atol/rtol unbound.
        tolerances = {
            TestSiluMulFp8QuantModel: (1e-3, 1e-3),
            TestSiluMulNvfp4QuantModel: (1e-1, 1e-1),
            TestSiluMulGroupFp8QuantModel: (5e-2, 5e-2),
        }
        atol, rtol = tolerances[model_class]

        torch.testing.assert_close(
            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
        )

        # Exactly one pattern must have been fused across all fusion passes.
        assert sum(p.matched_count for p in fusion_passes) == 1

        # In pre-nodes, quant op should be present and fused kernels should not
        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be present and quant op should not
        backend.check_after_ops(model.ops_in_model_after())