# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
8
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
9
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
10
from vllm.compilation.fusion import RMSNormQuantFusionPass
11
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
12
from vllm.compilation.noop_elimination import NoOpEliminationPass
13
from vllm.compilation.post_cleanup import PostCleanupPass
14
15
16
17
18
19
20
from vllm.config import (
    CompilationConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
21
22
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
23
24
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
25
26
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
27
28
29

from .backend import TestBackend

30
31
32
33
34
35
36
37
38
39
40
41
# Platform capability flag: the models below only build and run their FP8
# linear path when this is truthy.
TEST_FP8 = current_platform.supports_fp8()
# dtype the test models cast their random weights to on the FP8 path.
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
    """Toy model: ``SiluAndMul`` optionally followed by a static FP8 linear.

    When ``TEST_FP8`` is set, the activation output is fed through an
    ``Fp8LinearOp`` with static per-tensor activation quantization, so the
    activation+quant fusion pass has something to fuse. Otherwise the model
    is just the activation.
    """

    def __init__(self, hidden_size: int = 128):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Weight scale and (static) activation scale for the FP8 linear.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        if TEST_FP8:
            self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply the activation and, on FP8 platforms, the FP8 linear."""
        y = self.silu_and_mul(x)
        if not TEST_FP8:
            return y
        # The activation scale (self.scale) is the input scale; the weight
        # scale is passed separately as the third positional argument.
        return self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)

    def example_inputs(self, num_tokens=32, hidden_size=128):
        """Return a 1-tuple holding a random ``(num_tokens, 2 * hidden_size)`` tensor."""
        return (torch.rand(num_tokens, hidden_size * 2),)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the compiled graph.

        With FP8 available and fusion enabled this is the fused
        activation+quant op; otherwise the plain activation op.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.silu_and_mul_quant.default]
        return [torch.ops._C.silu_and_mul.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none here)."""
        return []


class TestFusedAddRMSNorm(torch.nn.Module):
    """Toy model: matmul -> RMSNorm with residual -> optional static FP8 linear.

    The input of width ``hidden_size`` is projected to ``intermediate_size``
    and normalized together with a residual. When ``TEST_FP8`` is set, the
    normalized output is additionally fed through an ``Fp8LinearOp`` with
    static activation quantization.
    """

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size))
        )
        self.norm = RMSNorm(intermediate_size, 1e-05)
        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if TEST_FP8:
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

            self.scale = torch.rand(1, dtype=torch.float32)
            self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        """Project, normalize with residual and (on FP8 platforms) run the FP8 linear.

        Returns a pair ``(output, residual_output)``.
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication: (tokens, hidden) @ (hidden, intermediate)
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # layer normalization
        norm_output, residual_output = self.norm(mm, residual)

        if not TEST_FP8:
            return norm_output, residual_output

        # scaled_mm with static input quantization
        fp8_linear_result = self.fp8_linear.apply(
            norm_output,
            self.w,
            self.wscale,
            input_scale=self.scale.to(norm_output.device),
        )
        return fp8_linear_result, residual_output

    def example_inputs(
        self, batch_size=8, hidden_size=16, seq_len=16, intermediate_size=32
    ):
        """Return random ``(hidden_states, residual)`` inputs for ``forward``.

        ``hidden_states`` is ``(batch_size * seq_len, hidden_size)``. The
        residual is passed to the norm together with the matmul output, whose
        last dimension is ``intermediate_size``, so it is created with shape
        ``(batch_size * seq_len, intermediate_size)`` to match.
        """
        num_tokens = batch_size * seq_len
        hidden_states = torch.randn((num_tokens, hidden_size))
        residual = torch.randn((num_tokens, intermediate_size))
        return (hidden_states, residual)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the compiled graph.

        With FP8 available and fusion enabled this is the fused norm+quant
        op; otherwise the plain fused add + RMSNorm op.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        return [torch.ops._C.fused_add_rms_norm.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none here)."""
        return []
128

129

130
class TestRotaryEmbedding(torch.nn.Module):
    """Toy model that only applies a rotary embedding to ``q`` and ``k``."""

    def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        # A falsy rotary_dim (None by default) falls back to head_dim.
        self.rotary_dim = rotary_dim or head_dim

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            rope_parameters={"rope_type": "default", "rope_theta": base},
        )

    def forward(self, positions, q, k):
        """Return the rotated ``(q, k)`` pair for the given positions."""
        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
        return q_rotated, k_rotated

    def example_inputs(self, num_tokens=32, head_dim=64):
        """Return ``(positions, q, k)``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``q`` and ``k`` are
        random ``(num_tokens, head_dim)`` tensors.
        """
        positions = torch.arange(num_tokens, dtype=torch.long)
        q = torch.randn(num_tokens, head_dim)
        k = torch.randn(num_tokens, head_dim)
        return (positions, q, k)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph (independent of ``do_fusion``)."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph (none here)."""
        return []


class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Toy model: QKV projection -> split -> rotary embedding -> concat.

    Exercises the case where the rotary embedding operates on slices of a
    larger tensor; the test expects no ``aten.slice_scatter`` to remain in
    the final graph (see ``ops_not_in_model``).
    """

    def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        # Single projection producing q, k and v side by side.
        self.qkv_proj = torch.nn.Linear(
            self.hidden_size, self.hidden_size * 3, bias=False
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters={"rope_type": "default", "rope_theta": base},
        )

    def forward(self, positions, hidden_states):
        """Project to QKV, rotate q and k, and return the re-concatenated QKV."""
        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes

        qkv = self.qkv_proj(hidden_states)
        split_sizes = [self.hidden_size] * 3
        q, k, v = torch.split(qkv, split_sizes, dim=-1)

        q_rotated, k_rotated = self.rotary_emb(positions, q, k)

        return torch.cat([q_rotated, k_rotated, v], dim=-1)

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        """Return ``(positions, hidden_states)``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``hidden_states``
        is a random ``(num_tokens, head_dim * num_heads)`` tensor.
        """
        hidden_size = head_dim * num_heads
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, hidden_size)
        return (positions, hidden_states)

    def ops_in_model(self, do_fusion):
        """Return the custom ops expected in the graph (independent of ``do_fusion``)."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the final graph."""
        return [torch.ops.aten.slice_scatter.default]


# Model classes the functionalization test is parametrized over. Each one
# provides example_inputs(), ops_in_model(do_fusion) and ops_not_in_model().
MODELS = [
    TestSiluMul,
    TestFusedAddRMSNorm,
    TestRotaryEmbedding,
    TestRotaryEmbeddingSliceScatter,
]


212
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(
    model_class: type[torch.nn.Module], do_fusion: bool, dtype: torch.dtype
):
    """Check that ``FixFunctionalizationPass`` de-functionalizes custom ops.

    The model is compiled twice with the same pass pipeline, once with and
    once without ``FixFunctionalizationPass`` appended. For every op in
    ``model.ops_in_model(do_fusion)``:

    * the graph compiled *without* the pass must still contain the
      auto-functionalized wrapper (``find_auto_fn`` is expected to fail
      otherwise -- see ``fx_utils``),
    * the graph compiled *with* the pass must not contain the wrapper and
      must contain a direct call to the op.

    Ops listed in ``model.ops_not_in_model()`` must not appear in the graph
    compiled with the pass.
    """
    # The default device/dtype are process-global; remember them so the test
    # does not leak its settings into tests that run afterwards.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)

    try:
        vllm_config = VllmConfig(
            model_config=ModelConfig(dtype=dtype),
            compilation_config=CompilationConfig(
                custom_ops=["all"],
                pass_config=PassConfig(
                    fuse_norm_quant=do_fusion,
                    fuse_act_quant=do_fusion,
                    eliminate_noops=True,
                ),
            ),
        )

        with set_current_vllm_config(vllm_config):
            assert RMSNorm.enabled()
            noop_pass = NoOpEliminationPass(vllm_config)
            fusion_pass = RMSNormQuantFusionPass(vllm_config)
            cleanup_pass = PostCleanupPass(vllm_config)
            act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

            passes = (
                [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
                if do_fusion
                else [noop_pass, cleanup_pass]
            )
            func_pass = FixFunctionalizationPass(vllm_config)

            backend_func = TestBackend(*passes, func_pass)
            backend_no_func = TestBackend(*passes)

            model = model_class()
            torch.compile(model, backend=backend_func)(*model.example_inputs())
            torch.compile(model, backend=backend_no_func)(*model.example_inputs())

            expected_ops = model.ops_in_model(do_fusion)
            forbidden_ops = model.ops_not_in_model()

            # check if the functionalization pass is applied
            for op in expected_ops:
                find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
                assert (
                    find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
                )

            # make sure the ops were all de-functionalized
            found = {
                op
                for node in backend_func.graph_post_pass.nodes
                for op in (*expected_ops, *forbidden_ops)
                if is_func(node, op)
            }
            missing = [op for op in expected_ops if op not in found]
            assert not missing, f"ops missing from the final graph: {missing}"
            unexpected = [op for op in forbidden_ops if op in found]
            assert not unexpected, f"ops left in the final graph: {unexpected}"
    finally:
        torch.set_default_dtype(prev_dtype)
        torch.set_default_device(prev_device)