test_silu_mul_quant_fusion.py 6.88 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
4

5
6
7
8
import pytest
import torch

import vllm.envs as envs
9
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
10
11
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
12
13
14
15
    FUSED_OPS,
    SILU_MUL_OP,
    ActivationQuantFusionPass,
)
16
from vllm.compilation.fusion import QUANT_OPS
17
from vllm.compilation.noop_elimination import NoOpEliminationPass
18
from vllm.compilation.post_cleanup import PostCleanupPass
19
20
21
22
23
24
25
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
26
from vllm.model_executor.layers.activation import SiluAndMul
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
28
29
30
31
    GroupShape,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
32
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
33
    Fp8LinearOp,
34
    maybe_create_device_identity,
35
)
36
from vllm.platforms import current_platform
37

38
from ..utils import override_cutlass_fp8_supported
39
40
from .backend import TestBackend

# FP8 dtype chosen by the current platform, and the dtype that carries FP4 data.
FP8_DTYPE, FP4_DTYPE = current_platform.fp8_dtype(), torch.uint8
def is_nvfp4_supported():
    """Ask the current platform whether it has device capability 100.

    The NVFP4 variant of the fusion test is skipped when this returns False.
    """
    nvfp4_capability = 100
    return current_platform.has_device_capability(nvfp4_capability)


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by a static per-tensor FP8 linear layer.

    Used to check that ActivationQuantFusionPass fuses the SiLU-and-mul
    activation with the static FP8 quantization that Fp8LinearOp applies to
    its input.

    Args:
        hidden_size: Size of the square FP8 weight; the activation fed to the
            linear layer must have this many features.
        cuda_force_torch: If True, Fp8LinearOp is built while CUTLASS FP8 is
            reported as unsupported, so the torch code path is exercised.
        **kwargs: Ignored; lets the test construct every model class with the
            same keyword arguments.
    """

    # Helper module, not a pytest test class despite the ``Test`` prefix.
    # Without this, pytest tries to collect it and warns because it defines
    # ``__init__``.
    __test__ = False

    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Random per-tensor scales: one for the weight, one for the activation.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        # Square FP8 weight, stored as a transposed view.
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

        # Whether each op runs as a custom op; ops_in_model_before uses these
        # to decide which op to look for in the unfused graph.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

    def forward(self, x):
        """Apply SiluAndMul, then the FP8 linear layer with static input quant."""
        y = self.silu_and_mul(x)
        # Static activation quantization uses the activation scale
        # (self.scale); the weight scale is passed separately.
        return self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs.

        These are the SiLU-mul custom op (or aten.mul when it is disabled) and
        the static FP8 quant custom op (or aten.reciprocal when it is disabled).
        """
        return [
            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
            (
                QUANT_OPS[kFp8StaticTensorSym]
                if self.enable_quant_fp8_custom_op
                else torch.ops.aten.reciprocal
            ),
        ]

    def ops_in_model_after(self):
        """Ops expected after fusion: the fused SiLU-mul + static FP8 quant op."""
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """SiluAndMul followed by NVFP4 activation quantization and a CUTLASS FP4 GEMM.

    Used to check that ActivationQuantFusionPass fuses the SiLU-and-mul
    activation with the NVFP4 quantization of its output.

    Args:
        hidden_size: Size of the square weight, which is quantized to NVFP4
            at construction time.
        x: Sample input. It is run through SiluAndMul once here so that the
            activation's global scale is fixed ahead of time.
        **kwargs: Ignored; lets the test construct every model class with the
            same keyword arguments.
    """

    # Helper module, not a pytest test class despite the ``Test`` prefix.
    # Without this, pytest tries to collect it and warns because it defines
    # ``__init__``.
    __test__ = False

    def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        super().__init__()
        from vllm.compilation.activation_quant_fusion import (
            silu_and_mul_nvfp4_quant_supported,
        )

        assert silu_and_mul_nvfp4_quant_supported, (
            "silu_and_mul_nvfp4_quant_supported is False"
        )

        self.silu_and_mul = SiluAndMul()
        # Whether SiluAndMul runs as a custom op; used by ops_in_model_before.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

        # NVFP4 weight together with its block scales and global scale.
        w = torch.rand((hidden_size, hidden_size))
        self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w)

        # Static ("offline") global scale for the activation, derived from the
        # sample input; the quantized values and block scales are discarded.
        _, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x))

        # alpha = 1 / (weight global scale * activation global scale), passed
        # to the FP4 GEMM in forward().
        self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)

    def forward(self, x):
        """Apply SiluAndMul, quantize to NVFP4, then run the scaled FP4 GEMM."""
        y = self.silu_and_mul(x)
        y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
        return cutlass_scaled_fp4_mm(
            a=y_quant,
            b=self.w,
            block_scale_a=y_block_scale,
            block_scale_b=self.w_block_scale,
            alpha=self.alpha,
            out_dtype=y.dtype,
        )

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs.

        These are the SiLU-mul custom op (or aten.mul when it is disabled) and
        the NVFP4 quant op.
        """
        return [
            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
            QUANT_OPS[kNvfp4Quant],
        ]

    def ops_in_model_after(self):
        """Ops expected after fusion: the fused SiLU-mul + NVFP4 quant op."""
        return [FUSED_OPS[kNvfp4Quant]]


@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
    + [(TestSiluMulNvfp4QuantModel, False, False)],
)
# cuda_force_torch is used to test the torch code path on platforms where
# cutlass_fp8_supported() == True.
@pytest.mark.skipif(
    envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
    enable_silu_mul_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
    cuda_force_torch: bool,
):
    """Check that ActivationQuantFusionPass fuses SiluAndMul with the following quant.

    The model is run eagerly and under torch.compile with a TestBackend that
    applies NoOpEliminationPass, the fusion pass and PostCleanupPass. The
    outputs must agree, and the pass must report exactly one match. The
    expected ops must be present before and after the passes run.
    """
    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
        pytest.skip("NVFP4 is not supported on this GPU.")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    maybe_create_device_identity()

    x = torch.rand(num_tokens, hidden_size * 2)

    # Custom op implementations explicitly enabled for this parametrization.
    custom_ops = []
    if enable_silu_mul_custom_op:
        custom_ops.append("+silu_and_mul")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            # enable_noop: the no-op (reshape) elimination pass is needed for
            # the fusion pass to work.
            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
        ),
    )

    with set_current_vllm_config(config):
        fusion_pass = ActivationQuantFusionPass(config)

        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
        backend = TestBackend(*passes)
        model = model_class(
            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
        )

        # First dimension dynamic
        torch._dynamo.mark_dynamic(x, 0)

        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Check that it gives the same answer. An unlisted model class raises
        # KeyError naming the class, rather than leaving atol/rtol unbound.
        tolerances = {
            TestSiluMulFp8QuantModel: (1e-3, 1e-3),
            TestSiluMulNvfp4QuantModel: (1e-1, 1e-1),
        }
        atol, rtol = tolerances[model_class]

        # Only the first row of each output is compared.
        # NOTE(review): comparing the full outputs would widen coverage;
        # confirm the tolerances still hold before doing so.
        torch.testing.assert_close(
            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
        )

        assert fusion_pass.matched_count == 1

        # In pre-nodes, quant op should be present and fused kernels should not
        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be present and quant op should not
        backend.check_after_ops(model.ops_in_model_after())