# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
7
8
9
10
11
12
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
    FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
14
from vllm.compilation.noop_elimination import NoOpEliminationPass
15
from vllm.config import CompilationConfig, PassConfig, VllmConfig
16
from vllm.model_executor.layers.activation import SiluAndMul
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
18
    GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
19
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
20
    Fp8LinearOp, cutlass_fp8_supported)
21
from vllm.platforms import current_platform
22

23
from ..utils import override_cutlass_fp8_supported
24
25
from .backend import TestBackend

26
27
# FP8 dtype reported by the current platform; used below to cast the FP8 test
# weight and the NVFP4 per-block weight scales.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype for the NVFP4 test weight. NOTE(review): presumably two 4-bit
# values are packed per byte, given the `hidden_size // 2` weight width used
# by the NVFP4 model below - confirm.
FP4_DTYPE = torch.uint8
28
29


30
31
32
33
34
35
def is_nvfp4_supported():
    """Return whether the current platform has device capability 100.

    The result decides whether the NVFP4 model is included in the
    ``model_class`` parametrization of the fusion test.
    """
    required_capability = 100
    return current_platform.has_device_capability(required_capability)


class TestSiluMulFp8QuantModel(torch.nn.Module):
    """Toy model: ``SiluAndMul`` followed by a static per-tensor FP8 linear.

    Used by the fusion test to check that the separate silu-mul and FP8 quant
    ops are replaced by the fused op.

    Args:
        hidden_size: Width of the square weight matrix. The test feeds an
            input of width ``2 * hidden_size``.
        cuda_force_torch: When true, cutlass FP8 support is overridden to
            "not supported" while ``Fp8LinearOp`` is constructed.
        **kwargs: Ignored; accepted so all test models can be built with the
            same keyword arguments.
    """

    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Weight scale and activation (input) scale are distinct tensors so a
        # fusion that wires the wrong scale into the fused op changes results.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        # Random weight cast to the platform FP8 dtype, then transposed.
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        # The override is only active while the op is being constructed.
        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply silu-and-mul, then the FP8 linear op with static scales."""
        y = self.silu_and_mul(x)
        # Quantize the activation with its own scale (`self.scale`), not the
        # weight scale; previously `self.scale` was created but never used.
        x2 = self.fp8_linear.apply(y,
                                   self.w,
                                   self.wscale,
                                   input_scale=self.scale)
        return x2

    def ops_in_model_before(self):
        """Ops the test expects in the graph before the fusion pass runs."""
        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]

    def ops_in_model_after(self):
        """Ops the test expects in the graph after the fusion pass runs."""
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """Toy model: ``SiluAndMul`` followed by NVFP4 quant and an FP4 matmul.

    Used by the fusion test to check that the separate silu-mul and NVFP4
    quant ops are replaced by the fused op.

    Args:
        hidden_size: Number of weight rows. The test feeds an input of width
            ``2 * hidden_size``.
        **kwargs: Ignored; accepted so all test models can be built with the
            same keyword arguments (e.g. ``cuda_force_torch``).
    """

    def __init__(self, hidden_size: int, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Random bytes as the weight. NOTE(review): the `hidden_size // 2`
        # width presumably reflects two FP4 values packed per byte - confirm.
        self.w = torch.randint(256, (hidden_size, hidden_size // 2),
                               dtype=FP4_DTYPE)
        # Weight block scales in FP8. NOTE(review): `hidden_size // 16`
        # presumably means one scale per 16 weight elements - confirm.
        self.wscale = torch.randn(hidden_size,
                                  hidden_size // 16).to(dtype=FP8_DTYPE)
        self.wscale2 = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        """Apply silu-and-mul, quantize to FP4, and run the scaled FP4 matmul."""
        y = self.silu_and_mul(x)
        # The quant op receives the reciprocal of `self.scale`; the matmul's
        # `alpha` multiplies `self.scale` back in together with `wscale2`.
        y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale)
        out = cutlass_scaled_fp4_mm(a=y_quant,
                                    b=self.w,
                                    block_scale_a=y_block_scale,
                                    block_scale_b=self.wscale,
                                    alpha=self.scale * self.wscale2,
                                    out_dtype=y.dtype)
        return out

    def ops_in_model_before(self):
        """Ops the test expects in the graph before the fusion pass runs."""
        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]

    def ops_in_model_after(self):
        """Ops the test expects in the graph after the fusion pass runs."""
        return [FUSED_OPS[kNvfp4Quant]]


@pytest.mark.parametrize("num_tokens", [64])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize(
    "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
    if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
                         [True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
                                   cuda_force_torch):
    """Check that silu-mul + quant is fused and the output is unchanged.

    The model is run eagerly and again through ``torch.compile`` with a
    backend holding the no-op elimination and activation-quant fusion
    passes. The two outputs are compared, then the backend is asked to check
    the ops present before and after the passes.
    """
    # The NVFP4 model ignores `cuda_force_torch`, so one of its two
    # parametrizations would be a duplicate run.
    if model_class is TestSiluMulNvfp4QuantModel and cuda_force_torch:
        pytest.skip("Duplicate tests for NVFP4")

    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    # Reshape pass is needed for the fusion pass to work
    # (NoOpEliminationPass is handed to the backend ahead of the fusion pass).
    config = VllmConfig()
    config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
    model = model_class(hidden_size=hidden_size,
                        cuda_force_torch=cuda_force_torch)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size * 2)
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer
    # NOTE(review): `[0]` compares only the first slice along dim 0 of each
    # output (the first token row if forward returns a 2-D tensor) - confirm
    # that this is intended rather than comparing the full output.
    torch.testing.assert_close(result[0].to(dtype=torch.float16),
                               result2[0].to(dtype=torch.float16),
                               atol=1e-3,
                               rtol=1e-3)

    # In pre-nodes, quant op should be present and fused kernels should not
    backend.check_before_ops(model.ops_in_model_before())

    # In post-nodes, fused kernels should be present and quant op should not
    backend.check_after_ops(model.ops_in_model_after())