test_silu_mul_quant_fusion.py 5.76 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from typing import cast

5
6
7
8
import pytest
import torch

import vllm.envs as envs
9
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
10
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
11

12
13
14
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
15
16
17
18
19
    FUSED_OPS,
    SILU_MUL_OP,
    ActivationQuantFusionPass,
)

20
21
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
22
from vllm.compilation.noop_elimination import NoOpEliminationPass
23
from vllm.compilation.post_cleanup import PostCleanupPass
24
from vllm.config import CompilationConfig, PassConfig, VllmConfig
25
from vllm.model_executor.layers.activation import SiluAndMul
26
from vllm.model_executor.layers.quantization.utils.quant_utils import (
27
28
29
30
    GroupShape,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
31
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
32
33
34
    Fp8LinearOp,
    cutlass_fp8_supported,
)
35
from vllm.platforms import current_platform
36

37
from ..utils import override_cutlass_fp8_supported
38
39
from .backend import TestBackend

40
41
# FP8 dtype reported by the current platform; used for the quantized weight
# of the FP8 test model.
FP8_DTYPE = current_platform.fp8_dtype()
# NOTE(review): not referenced by the code visible in this file; presumably
# the storage dtype of NVFP4 tensors — confirm before removing.
FP4_DTYPE = torch.uint8
42
43


44
45
46
47
48
def is_nvfp4_supported():
    """Return whether the current platform reports device capability 100.

    Used to decide whether the NVFP4 model variant is parametrized in.
    """
    required_capability = 100
    supported = current_platform.has_device_capability(required_capability)
    return supported


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """SiLU-and-mul followed by a static per-tensor FP8 linear layer.

    The activation quantization performed inside ``Fp8LinearOp.apply`` is the
    op that ``ActivationQuantFusionPass`` is expected to fuse with SiLU-and-mul
    (see ``ops_in_model_before`` / ``ops_in_model_after``).
    """

    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Independent random scales for the weight and for the static
        # activation quantization. Keeping them distinct makes it more likely
        # that a fused op wired to the wrong scale tensor is caught.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        # cuda_force_torch=True reports CUTLASS FP8 as unsupported while the op
        # is built, so the torch code path of Fp8LinearOp is exercised.
        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply SiLU-and-mul, then the FP8 linear layer with a static input scale."""
        y = self.silu_and_mul(x)
        # The activation is quantized with its own static scale, not the
        # weight scale.
        x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)
        return x2

    def ops_in_model_before(self):
        """Ops expected in the graph before fusion: SiLU-and-mul and FP8 quant."""
        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]

    def ops_in_model_after(self):
        """Ops expected in the graph after fusion: the fused FP8 kernel."""
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """SiLU-and-mul, NVFP4 activation quantization, then an FP4 CUTLASS GEMM.

    The weight and both global scales are computed once at construction time
    from ``hidden_size`` and a sample input ``x``.
    """

    def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        super().__init__()
        from vllm.compilation.activation_quant_fusion import (
            silu_and_mul_nvfp4_quant_supported,
        )

        # Fail early unless the flag is truthy (presumably: the fused
        # SiLU-and-mul + NVFP4 quant op is available in this build).
        assert silu_and_mul_nvfp4_quant_supported

        self.silu_and_mul = SiluAndMul()

        # Quantize a random square weight to NVFP4 once, up front.
        weight_fp = torch.rand((hidden_size, hidden_size))
        (
            self.w,
            self.w_block_scale,
            self.w_global_scale,
        ) = quant_nvfp4_tensor(weight_fp)

        # Derive the activation global scale offline from a sample activation.
        sample_act = self.silu_and_mul(x)
        _unused_q, _unused_block_scale, self.y_global_scale = quant_nvfp4_tensor(
            sample_act
        )

        # GEMM output multiplier: reciprocal of the product of both global scales.
        self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)

    def forward(self, x):
        """Quantize the SiLU-and-mul output to NVFP4 and multiply by the weight."""
        activated = self.silu_and_mul(x)
        act_fp4, act_block_scale = scaled_fp4_quant(activated, self.y_global_scale)
        return cutlass_scaled_fp4_mm(
            a=act_fp4,
            b=self.w,
            block_scale_a=act_block_scale,
            block_scale_b=self.w_block_scale,
            alpha=self.alpha,
            out_dtype=activated.dtype,
        )

    def ops_in_model_before(self):
        """Ops expected in the graph before fusion: SiLU-and-mul and NVFP4 quant."""
        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]

    def ops_in_model_after(self):
        """Ops expected in the graph after fusion: the fused NVFP4 kernel."""
        return [FUSED_OPS[kNvfp4Quant]]


115
116
117
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "model_class",
    cast(
        list[type],
        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
        if is_nvfp4_supported()
        else [TestSiluMulFp8QuantModel],
    ),
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
):
    """Check that ActivationQuantFusionPass fuses SiLU-and-mul with its quant op.

    The model is run eagerly and through ``torch.compile`` with the no-op
    elimination, activation-quant fusion, and post-cleanup passes. The test
    asserts that:

    * the two outputs agree within a per-model tolerance;
    * exactly one fusion pattern matched;
    * the unfused ops are present before the passes and replaced afterwards.
    """
    if model_class is TestSiluMulNvfp4QuantModel and cuda_force_torch:
        # The NVFP4 model ignores cuda_force_torch, so this case would only
        # repeat the cuda_force_torch=False run.
        pytest.skip("Duplicate tests for NVFP4")

    # NVFP4 is much coarser than FP8, hence the looser tolerance.
    tolerances = {
        TestSiluMulFp8QuantModel: (1e-3, 1e-3),
        TestSiluMulNvfp4QuantModel: (1e-1, 1e-1),
    }
    atol, rtol = tolerances[model_class]

    # Default device/dtype are process-global; restore them afterwards so that
    # later tests in the same session are not affected.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    try:
        x = torch.rand(num_tokens, hidden_size * 2)

        # The no-op elimination pass (enable_noop) is needed for the fusion
        # pass to match.
        config = VllmConfig()
        config.compilation_config = CompilationConfig(
            pass_config=PassConfig(enable_fusion=True, enable_noop=True)
        )
        fusion_pass = ActivationQuantFusionPass(config)

        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
        backend = TestBackend(*passes)
        model = model_class(
            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
        )

        # First dimension dynamic
        torch._dynamo.mark_dynamic(x, 0)

        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Compare the whole output, not just the first token row.
        torch.testing.assert_close(
            result.to(dtype=dtype), result2.to(dtype=dtype), atol=atol, rtol=rtol
        )

        assert fusion_pass.matched_count == 1

        # In pre-nodes, quant op should be present and fused kernels should not
        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be present and quant op should not
        backend.check_after_ops(model.ops_in_model_after())
    finally:
        torch.set_default_device(prev_device)
        torch.set_default_dtype(prev_dtype)