# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    ModelConfig,
    PassConfig,
    RendererConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from .backend import TestBackend

# Whether the current platform reports FP8 support; gates the FP8 linear /
# quant-fusion paths in the toy models below.
TEST_FP8 = current_platform.supports_fp8()
# FP8 dtype reported by the current platform; used to build the FP8 weights.
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
    """Toy model: ``SiluAndMul``, optionally followed by an FP8 linear layer.

    When ``TEST_FP8`` is true the activation output is fed into an
    ``Fp8LinearOp`` configured for static, per-tensor activation
    quantization; otherwise the activation output is returned directly.
    """

    def __init__(self, hidden_size: int = 128):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Weight scale and (static) activation/input scale for the FP8 linear.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        if TEST_FP8:
            # Random square weight cast to the platform FP8 dtype, transposed.
            self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply SiluAndMul and, on FP8-capable platforms, the FP8 linear."""
        y = self.silu_and_mul(x)
        if not TEST_FP8:
            return y
        # Use the dedicated activation scale as ``input_scale`` (consistent
        # with TestFusedAddRMSNorm in this file); ``self.wscale`` is the
        # weight scale and was previously passed for both arguments, leaving
        # ``self.scale`` unused.
        return self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)

    def example_inputs(self, num_tokens=32, hidden_size=128):
        """Return a 1-tuple with a random ``(num_tokens, 2 * hidden_size)`` input."""
        return (torch.rand(num_tokens, hidden_size * 2),)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph for this configuration."""
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.silu_and_mul_quant.default]
        else:
            return [torch.ops._C.silu_and_mul.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none)."""
        return []


class TestFusedAddRMSNorm(torch.nn.Module):
    """Toy model: matmul -> RMSNorm with residual, optionally -> FP8 linear.

    When ``TEST_FP8`` is true the normalized output is fed into an
    ``Fp8LinearOp`` with static activation quantization; otherwise the
    normalized output is returned directly. The residual output of the norm
    is returned in both cases.
    """

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size))
        )
        self.norm = RMSNorm(intermediate_size, 1e-05)
        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if TEST_FP8:
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

            # Static activation scale, FP8 weight (transposed) and weight scale.
            self.scale = torch.rand(1, dtype=torch.float32)
            self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        """Return ``(output, residual_output)`` for the given inputs."""
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # layer normalization
        norm_output, residual_output = self.norm(mm, residual)

        if TEST_FP8:
            # scaled_mm with static input quantization
            fp8_linear_result = self.fp8_linear.apply(
                norm_output,
                self.w,
                self.wscale,
                input_scale=self.scale.to(norm_output.device),
            )

            return fp8_linear_result, residual_output

        else:
            return norm_output, residual_output

    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
        """Return random ``(hidden_states, residual)`` tensors.

        Both tensors have shape ``(batch_size * seq_len, hidden_size)``.
        """
        # NOTE(review): ``residual`` has ``hidden_size`` columns, while the
        # norm in ``forward`` runs on the matmul output, which has
        # ``intermediate_size`` columns -- confirm this mismatch is intended.
        hidden_states = torch.randn((batch_size * seq_len, hidden_size))
        residual = torch.randn((batch_size * seq_len, hidden_size))
        return (hidden_states, residual)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph for this configuration."""
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        else:
            return [torch.ops._C.fused_add_rms_norm.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none)."""
        return []


class TestRotaryEmbedding(torch.nn.Module):
    """Toy model that applies a rotary embedding to query and key tensors."""

    def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        # Fall back to the full head dimension when rotary_dim is not given.
        self.rotary_dim = rotary_dim or head_dim

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            rope_parameters={"rope_type": "default", "rope_theta": base},
        )

    def forward(self, positions, q, k):
        """Return the rotated ``(q, k)`` pair for the given positions."""
        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
        return q_rotated, k_rotated

    def example_inputs(self, num_tokens=32, head_dim=64):
        """Return ``(positions, q, k)``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``q`` and ``k`` are
        random tensors of shape ``(num_tokens, head_dim)``.
        """
        positions = torch.arange(num_tokens, dtype=torch.long)
        q = torch.randn(num_tokens, head_dim)
        k = torch.randn(num_tokens, head_dim)
        return (positions, q, k)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph (fusion-independent)."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none)."""
        return []


class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Toy model: fused QKV projection -> split -> rotary embedding -> concat.

    Used to check that no ``aten.slice_scatter`` op remains in the final
    graph (see ``ops_not_in_model``).
    """

    def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        # Single bias-free projection producing q, k and v side by side.
        self.qkv_proj = torch.nn.Linear(
            self.hidden_size, self.hidden_size * 3, bias=False
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters={"rope_type": "default", "rope_theta": base},
        )

    def forward(self, positions, hidden_states):
        """Project to QKV, rotate q and k, and return the re-concatenated QKV."""
        # Intended to exercise the traced pattern:
        # mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes

        qkv = self.qkv_proj(hidden_states)
        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
        q, k, v = torch.split(qkv, split_sizes, dim=-1)

        q_rotated, k_rotated = self.rotary_emb(positions, q, k)

        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
        return qkv_updated

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        """Return ``(positions, hidden_states)``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``hidden_states`` is
        a random tensor of shape ``(num_tokens, head_dim * num_heads)``.
        """
        hidden_size = head_dim * num_heads
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, hidden_size)
        return (positions, hidden_states)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph (fusion-independent)."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph."""
        return [torch.ops.aten.slice_scatter.default]


# Model classes exercised by test_fix_functionalization (one test case each,
# per dtype and fusion setting).
MODELS = [
    TestSiluMul,
    TestFusedAddRMSNorm,
    TestRotaryEmbedding,
    TestRotaryEmbeddingSliceScatter,
]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(
    model_class: type[torch.nn.Module], do_fusion: bool, dtype: torch.dtype
):
    """Check that FixFunctionalizationPass de-functionalizes the custom ops.

    The model is compiled twice with the same passes, once with and once
    without ``FixFunctionalizationPass`` appended. The graph compiled without
    it must contain the auto-functionalized form of each expected op; the
    graph compiled with it must contain none, and must instead contain the
    direct ops (and none of the model's forbidden ops).
    """
    # The toy models create tensors without an explicit device/dtype, so they
    # pick up these defaults.
    # NOTE(review): these process-wide defaults are not restored here --
    # confirm a fixture resets them between tests.
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)

    model_config = ModelConfig(dtype=dtype)
    vllm_config = VllmConfig(
        model_config=model_config,
        renderer_config=RendererConfig(model_config=model_config),
        compilation_config=CompilationConfig(
            custom_ops=["all"],
            pass_config=PassConfig(
                fuse_norm_quant=do_fusion,
                fuse_act_quant=do_fusion,
                eliminate_noops=True,
            ),
        ),
    )

    with set_current_vllm_config(vllm_config):
        assert RMSNorm.enabled()
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

        passes = (
            [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
            if do_fusion
            else [noop_pass, cleanup_pass]
        )
        func_pass = FixFunctionalizationPass(vllm_config)

        backend_func = TestBackend(*passes, func_pass)
        backend_no_func = TestBackend(*passes)

        model = model_class()
        torch.compile(model, backend=backend_func)(*model.example_inputs())
        torch.compile(model, backend=backend_no_func)(*model.example_inputs())

        expected_ops = model.ops_in_model(do_fusion)
        forbidden_ops = model.ops_not_in_model()

        # check if the functionalization pass is applied
        for op in expected_ops:
            # Presumably raises if the auto-functionalized op is missing from
            # the graph compiled without the fix pass -- verify in fx_utils.
            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None

        # make sure the ops were all de-functionalized
        found = set()
        for node in backend_func.graph_post_pass.nodes:
            for op in (*expected_ops, *forbidden_ops):
                if is_func(node, op):
                    found.add(op)

        # Report missing/unexpected ops through the assertion message instead
        # of failing with a bare KeyError on a dict lookup.
        missing = [op for op in expected_ops if op not in found]
        assert not missing, f"expected ops not found in graph: {missing}"
        unexpected = [op for op in forbidden_ops if op in found]
        assert not unexpected, f"forbidden ops found in graph: {unexpected}"