test_fusion.py 7.27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pytest
import torch

7
import vllm.plugins
8
9
10
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
11
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
from vllm.compilation.post_cleanup import PostCleanupPass
13
14
15
16
17
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
18
    RendererConfig,
19
20
    VllmConfig,
)
21
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.quantization.utils.quant_utils import (
23
24
25
26
    GroupShape,
    QuantKey,
    ScaleDesc,
)
27
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
28
29
30
31
    Fp8LinearOp,
    cutlass_fp8_supported,
    maybe_create_device_identity,
)
32
from vllm.platforms import current_platform
33

34
from ..utils import override_cutlass_fp8_supported
35
36
from .backend import TestBackend

37
38
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Default overloads of the `rms_norm` and `fused_add_rms_norm` custom ops;
# the test looks for these in the graph before the fusion pass runs.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

42
43

class TestModel(torch.nn.Module):
    """Toy network used to exercise the RMSNorm + FP8-quant fusion pass.

    ``forward`` chains four ``RMSNorm`` layers with three FP8 linear layers
    in between, so the output of each of the first three norms is fed into
    ``Fp8LinearOp.apply``.

    Args:
        hidden_size: Size passed to every ``RMSNorm`` and used for both
            dimensions of every weight matrix.
        eps: Epsilon passed to every ``RMSNorm``.
        static: Selects ``GroupShape.PER_TENSOR`` with explicit random input
            scales when true, ``GroupShape.PER_TOKEN`` with ``None`` input
            scales otherwise.
        cuda_force_torch: ``override_cutlass_fp8_supported`` is entered with
            ``not cuda_force_torch`` while the ``Fp8LinearOp`` is built.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        static: bool,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch

        # NOTE: the order of the torch.rand calls below is kept as-is so a
        # fixed seed produces the same tensors.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)

        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        else:
            self.scale = [None for _ in range(3)]

        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(3)
        ]

        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=static,
                act_quant_group_shape=group_shape,
            )

        # Remember which custom ops are enabled so the expected-op helpers
        # below can pick the matching op to look for.
        self.enable_rms_norm_custom_op = self.norm[0].enabled()
        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

    def forward(self, x):
        """Run norm -> fp8 linear three times, then a final norm."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Return the fused ops looked up for this model's quant key.

        Both ``FusedRMSQuantKey`` variants (second argument ``True`` and
        ``False``) are included.
        """
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
        ]

    def ops_in_model_before(self):
        """Return the quant-side op to look for before fusion.

        This is the quant custom op for ``self.quant_key`` when the quant
        custom op is enabled, otherwise ``torch.ops.aten.reciprocal``.
        """
        return (
            [QUANT_OPS[self.quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_before_partial(self):
        """Return the norm-side ops to look for before fusion.

        These are the RMSNorm custom ops when the RMSNorm custom op is
        enabled, otherwise ``torch.ops.aten.rsqrt``. The test checks them
        with ``fully_replaced=False``.
        """
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )

122
123

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype,
    hidden_size,
    num_tokens,
    eps,
    static,
    enable_rms_norm_custom_op,
    enable_quant_fp8_custom_op,
    cuda_force_torch,
):
    """Check that RMSNorm + quant fusion matches the unfused result.

    The same ``TestModel`` is compiled twice: once with a backend that runs
    the fusion pass and once with a backend that does not. The outputs must
    agree within a dtype-dependent tolerance, the fusion pass must report
    exactly three matches, and the pre-/post-pass graphs must contain the
    expected ops.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    # Explicitly enable only the custom ops requested by the parametrization.
    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    model_config = ModelConfig(dtype=dtype)
    vllm_config = VllmConfig(
        model_config=model_config,
        renderer_config=RendererConfig(model_config=model_config),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
            ),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config):
        # The no-op elimination pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        # `backend` runs the fusion pass; `backend2` is the unfused reference.
        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        backend2 = TestBackend(noop_pass, cleanup_pass)
        model = TestModel(hidden_size, eps, static, cuda_force_torch)

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        model_fused = torch.compile(model, backend=backend)
        result_fused = model_fused(x)

        model_unfused = torch.compile(model, backend=backend2)
        result_unfused = model_unfused(x)

        if dtype == torch.float16:
            atol, rtol = (2e-3, 2e-3)
        else:
            atol, rtol = (1e-2, 1e-2)

        torch.testing.assert_close(result_fused, result_unfused, atol=atol, rtol=rtol)

        assert fusion_pass.matched_count == 3
        backend.check_before_ops(model.ops_in_model_before())
        backend.check_before_ops(
            model.ops_in_model_before_partial(), fully_replaced=False
        )
        backend.check_after_ops(model.ops_in_model_after())

        # If RMSNorm custom op is disabled (native/torch impl used),
        # there's a risk that the fused add doesn't get included in the
        # replacement and only the rms part gets fused with quant.
        # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
        if not enable_rms_norm_custom_op:

            def n_add_nodes(graph):
                """Count the `aten.add` nodes found in ``graph``."""
                return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

            # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
            assert n_add_nodes(backend.graph_pre_pass) == 7
            assert n_add_nodes(backend.graph_post_pass) == 2