# test_silu_mul_quant_fusion.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from typing import cast

5
6
7
8
import pytest
import torch

import vllm.envs as envs
9
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
10
11
12
13
14
15
16
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
    FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
17
from vllm.compilation.noop_elimination import NoOpEliminationPass
18
from vllm.compilation.post_cleanup import PostCleanupPass
19
from vllm.config import CompilationConfig, PassConfig, VllmConfig
20
from vllm.model_executor.layers.activation import SiluAndMul
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
22
    GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
23
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
24
    Fp8LinearOp, cutlass_fp8_supported)
25
from vllm.platforms import current_platform
26

27
from ..utils import override_cutlass_fp8_supported
28
29
from .backend import TestBackend

30
31
# Platform-specific FP8 dtype, as reported by current_platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype for FP4 tensors (not referenced by the tests below).
FP4_DTYPE = torch.uint8


def is_nvfp4_supported() -> bool:
    """Return whether the current platform reports device capability 100.

    NVFP4 test cases are only parametrized in when this returns True.
    """
    required_capability = 100
    return current_platform.has_device_capability(required_capability)


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """SiLU-and-mul followed by a static, per-tensor FP8 linear layer.

    Before fusion the graph contains ``SILU_MUL_OP`` and the static FP8
    quant op; after ``ActivationQuantFusionPass`` it should contain the fused
    SiLU-mul + FP8 quant op instead.
    """

    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Weight scale and static activation (input) scale.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        # Random FP8 weight, stored as a transposed view.
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        # With cuda_force_torch, report CUTLASS FP8 as unsupported while
        # building the op so that Fp8LinearOp exercises its torch code path.
        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply SiLU-and-mul, then the FP8 linear layer; return its output."""
        y = self.silu_and_mul(x)
        # The static activation quantization uses its own scale, distinct
        # from the weight scale.
        out = self.fp8_linear.apply(y,
                                    self.w,
                                    self.wscale,
                                    input_scale=self.scale)
        return out

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs."""
        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]

    def ops_in_model_after(self):
        """Ops expected in the graph after the fusion pass runs."""
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """SiLU-and-mul followed by NVFP4 quantization and a CUTLASS FP4 GEMM.

    The activation's global scale is computed once at construction time from
    the sample input ``x``. Eager and compiled runs therefore quantize with the
    same scale. ``cuda_force_torch`` (if passed) is accepted via ``**kwargs``
    and ignored.
    """

    def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
        super().__init__()
        # This model is only meaningful when the fused NVFP4 kernel exists.
        from vllm.compilation.activation_quant_fusion import (
            silu_and_mul_nvfp4_quant_supported)
        assert silu_and_mul_nvfp4_quant_supported

        self.silu_and_mul = SiluAndMul()

        # Quantize a random square weight to NVFP4: quantized values,
        # per-block scales and global scale.
        weight = torch.rand((hidden_size, hidden_size))
        (self.w, self.w_block_scale,
         self.w_global_scale) = quant_nvfp4_tensor(weight)

        # Offline calibration: only the global scale of the activation is kept.
        calib_activation = self.silu_and_mul(x)
        self.y_global_scale = quant_nvfp4_tensor(calib_activation)[2]

        # GEMM output rescale combining both global scales.
        self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)

    def forward(self, x):
        """SiLU-and-mul, NVFP4-quantize, then multiply by the NVFP4 weight."""
        activated = self.silu_and_mul(x)
        act_fp4, act_block_scale = scaled_fp4_quant(activated,
                                                    self.y_global_scale)
        gemm_kwargs = dict(a=act_fp4,
                           b=self.w,
                           block_scale_a=act_block_scale,
                           block_scale_b=self.w_block_scale,
                           alpha=self.alpha,
                           out_dtype=activated.dtype)
        return cutlass_scaled_fp4_mm(**gemm_kwargs)

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs."""
        expected = [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
        return expected

    def ops_in_model_after(self):
        """Ops expected in the graph after the fusion pass runs."""
        expected = [FUSED_OPS[kNvfp4Quant]]
        return expected


106
107
108
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "model_class",
    cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
         if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
                         [True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
                                   cuda_force_torch):
    """Check that ActivationQuantFusionPass fuses SiLU-and-mul with quant.

    The model runs eagerly and through ``torch.compile`` with the
    no-op-elimination -> fusion -> cleanup pass pipeline. The test asserts:

    - the full outputs match within per-model tolerances;
    - the fusion pattern matched exactly once;
    - the expected op sets are present before and after the pass.
    """
    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
        # The NVFP4 model ignores cuda_force_torch, so this case would just
        # repeat the cuda_force_torch=False run.
        pytest.skip("Duplicate tests for NVFP4")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)

    x = torch.rand(num_tokens, hidden_size * 2)

    # The no-op (reshape) elimination pass is needed for the fusion pass
    # to work.
    config = VllmConfig()
    config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    fusion_pass = ActivationQuantFusionPass(config)

    passes = [
        NoOpEliminationPass(config), fusion_pass,
        PostCleanupPass(config)
    ]
    backend = TestBackend(*passes)
    model = model_class(hidden_size=hidden_size,
                        cuda_force_torch=cuda_force_torch,
                        x=x)

    # First dimension dynamic
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer. Both models return a single
    # tensor, so compare all tokens rather than indexing row 0 only.
    tolerances = {
        TestSiluMulFp8QuantModel: (1e-3, 1e-3),
        # FP4 quantization is much coarser than FP8.
        TestSiluMulNvfp4QuantModel: (1e-1, 1e-1),
    }
    atol, rtol = tolerances[model_class]
    torch.testing.assert_close(result.to(dtype=dtype),
                               result2.to(dtype=dtype),
                               atol=atol,
                               rtol=rtol)

    assert fusion_pass.matched_count == 1

    # In pre-nodes, quant op should be present and fused kernels should not
    backend.check_before_ops(model.ops_in_model_before())

    # In post-nodes, fused kernels should be present and quant op should not
    backend.check_after_ops(model.ops_in_model_after())