# test_functionalization.py 11.3 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import copy

6
7
8
import pytest
import torch

9
10
11
12
13
14
15
16
17
18
19
20
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer
from vllm.compilation.passes.fusion.act_quant_fusion import (
    ActivationQuantFusionPass,
)
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.passes.utility.fix_functionalization import (
    FixFunctionalizationPass,
)
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
21
22
23
24
25
26
27
from vllm.config import (
    CompilationConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
28
29
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
30
31
32
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
33
34
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
35
from vllm.utils.torch_utils import direct_register_custom_op
36

37
38
39
40
41
# Whether the current platform supports FP8; the test models below only build
# and run their FP8 linear layers (and expect FP8 fused ops) when this is truthy.
TEST_FP8 = current_platform.supports_fp8()
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
    """SiluAndMul, followed by a static FP8 linear when FP8 is supported.

    With FP8 available and fusion requested, the test expects the fused
    ``silu_and_mul_quant`` op in the graph; otherwise plain ``silu_and_mul``.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size: int = 128):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        if not TEST_FP8:
            return
        # Activation and weight both use the class-level static FP8 quant key.
        self.fp8_linear = TestFP8Layer(
            weight_shape=(hidden_size, hidden_size),
            activation_quant_key=self.quant_key,
            weight_quant_key=self.quant_key,
        )

    def forward(self, x):
        activated = self.silu_and_mul(x)
        return self.fp8_linear(activated) if TEST_FP8 else activated

    def example_inputs(self, num_tokens=32, hidden_size=128):
        # The input is twice hidden_size wide; SiluAndMul is assumed to bring
        # it back to hidden_size to match fp8_linear's weight shape.
        doubled_width = hidden_size * 2
        return (torch.rand(num_tokens, doubled_width),)

    def ops_in_model(self, do_fusion):
        use_fused_op = TEST_FP8 and do_fusion
        op = (
            torch.ops._C.silu_and_mul_quant.default
            if use_fused_op
            else torch.ops._C.silu_and_mul.default
        )
        return [op]

    def ops_not_in_model(self):
        # Nothing is required to be absent for this model.
        return []


class TestFusedAddRMSNorm(torch.nn.Module):
    """mm -> residual-add RMSNorm, optionally followed by a static FP8 linear.

    The test expects ``fused_add_rms_norm`` in the graph, or
    ``fused_add_rms_norm_static_fp8_quant`` when FP8 is available and fusion
    is requested.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Stored as (intermediate_size, hidden_size); forward() transposes it.
        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size))
        )
        self.norm = RMSNorm(intermediate_size, 1e-05)
        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if TEST_FP8:
            # Consumes the intermediate_size-wide norm output.
            self.fp8_linear = TestFP8Layer(
                weight_shape=(hidden_size, intermediate_size),
                activation_quant_key=self.quant_key,
                weight_quant_key=self.quant_key,
            )

    def forward(self, hidden_states, residual):
        # Reshape input to (num_tokens, hidden_size).
        view = hidden_states.reshape(-1, self.hidden_size)

        # (num_tokens, hidden_size) @ (hidden_size, intermediate_size)
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # Residual-add + RMSNorm; returns (normalized output, updated residual).
        norm_output, residual_output = self.norm(mm, residual)

        if TEST_FP8:
            # scaled_mm with static input quantization
            fp8_linear_result = self.fp8_linear(norm_output)
            return fp8_linear_result, residual_output

        return norm_output, residual_output

    def example_inputs(self, batch_size=8, hidden_size=None, seq_len=16):
        """Return ``(hidden_states, residual)`` for ``batch_size * seq_len`` tokens.

        Args:
            batch_size: Number of sequences.
            hidden_size: Width of ``hidden_states``; defaults to the model's
                ``hidden_size`` (it must match, since forward() reshapes to it).
            seq_len: Tokens per sequence.

        The residual is added to the ``mm`` output inside the norm, so it has
        ``intermediate_size`` columns. Sizing it with ``hidden_size`` made it
        mismatch the mm output and the fused CUDA kernel's expectations.
        """
        if hidden_size is None:
            hidden_size = self.hidden_size
        num_tokens = batch_size * seq_len
        hidden_states = torch.randn((num_tokens, hidden_size))
        residual = torch.randn((num_tokens, self.intermediate_size))
        return (hidden_states, residual)

    def ops_in_model(self, do_fusion):
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        else:
            return [torch.ops._C.fused_add_rms_norm.default]

    def ops_not_in_model(self):
        return []
130

131

132
class TestRotaryEmbedding(torch.nn.Module):
    """Apply vLLM rotary embedding directly to standalone q/k tensors."""

    def __init__(self, head_dim=64, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        # "default" rope type, with ``base`` used as rope_theta.
        rope_parameters = {"rope_type": "default", "rope_theta": base}
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )

    def forward(self, positions, q, k):
        rotated_q, rotated_k = self.rotary_emb(positions, q, k)
        return rotated_q, rotated_k

    def example_inputs(self, num_tokens=32, head_dim=64):
        positions = torch.arange(num_tokens, dtype=torch.long)
        # q is drawn before k; both are (num_tokens, head_dim).
        query, key = (torch.randn(num_tokens, head_dim) for _ in range(2))
        return (positions, query, key)

    def ops_in_model(self, do_fusion):
        # No fusion applies here; do_fusion is accepted for interface parity.
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        return []


class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Fused QKV projection whose q/k slices are passed through rotary embedding.

    The test requires that no ``aten.slice_scatter`` remains in the final graph
    (see ``ops_not_in_model``).
    """

    def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        # One projection producing q, k and v concatenated on the last dim.
        self.qkv_proj = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        rope_parameters = {"rope_type": "default", "rope_theta": base}
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )

    def forward(self, positions, hidden_states):
        # Intended traced pattern: mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes
        q, k, v = torch.split(
            self.qkv_proj(hidden_states), [self.hidden_size] * 3, dim=-1
        )
        q, k = self.rotary_emb(positions, q, k)
        return torch.cat([q, k, v], dim=-1)

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, num_heads * head_dim)
        return (positions, hidden_states)

    def ops_in_model(self, do_fusion):
        # No fusion applies here; do_fusion is accepted for interface parity.
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        return [torch.ops.aten.slice_scatter.default]


203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module):
    """Model around a custom op that both mutates its input and returns a value.

    ``torch.ops.vllm.function_with_mutated_args_and_return(x)`` returns
    ``x + 1`` (computed before the mutation) and adds 2 to ``x`` in place, so
    it exercises an op with a mutated argument *and* a regular return value.
    """

    # Ensures the custom op is registered at most once per process, since the
    # class may be instantiated several times.
    OP_REGISTERED = False

    def __init__(self):
        super().__init__()
        self.register_test_custom_op()

    @classmethod
    def register_test_custom_op(cls):
        """Register the custom op under the ``vllm`` namespace (once)."""
        if not cls.OP_REGISTERED:

            def function_with_mutated_args_and_return_impl(
                x: torch.Tensor,
            ) -> torch.Tensor:
                ret = x + 1
                x.add_(2)
                return ret

            def function_with_mutated_args_and_return_fake(
                x: torch.Tensor,
            ) -> torch.Tensor:
                # Fake impl only needs the output's metadata (shape/dtype).
                return torch.empty_like(x)

            direct_register_custom_op(
                op_name="function_with_mutated_args_and_return",
                op_func=function_with_mutated_args_and_return_impl,
                mutates_args=["x"],
                fake_impl=function_with_mutated_args_and_return_fake,
            )

            cls.OP_REGISTERED = True

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x is deliberately NOT cloned: the op mutates it in place, and
        # returning the mutated input exercises the mutated-argument path.
        ret = torch.ops.vllm.function_with_mutated_args_and_return(x)
        return x, ret

    def example_inputs(self, num_tokens=32):
        hidden_states = torch.randn(num_tokens)
        return (hidden_states,)

    def ops_in_model(self, do_fusion):
        return [torch.ops.vllm.function_with_mutated_args_and_return.default]

    def ops_not_in_model(self):
        return []


# Maps each test model class to the do_fusion values it is parametrized with.
# Only TestSiluMul and TestFusedAddRMSNorm are run with quant fusion enabled;
# the remaining models are tested without fusion.
MODELS_AND_DO_FUSION: dict[type[torch.nn.Module], list[bool]] = {
    TestSiluMul: [True, False],
    TestFusedAddRMSNorm: [True, False],
    TestRotaryEmbedding: [False],
    TestRotaryEmbeddingSliceScatter: [False],
    TestFunctionWithMutatedArgsAndReturn: [False],
}
258
259


260
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "model_class, do_fusion",
    [
        (model_class, do_fusion)
        for model_class, fusions in MODELS_AND_DO_FUSION.items()
        for do_fusion in fusions
    ],
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="Only test on cuda and rocm platform",
)
def test_fix_functionalization(
    model_class: type[torch.nn.Module], do_fusion: bool, dtype: torch.dtype
):
    """Check that FixFunctionalizationPass de-functionalizes in-place ops.

    Each model is compiled twice with the same pre-passes: once with
    FixFunctionalizationPass appended and once without. The post-pass graphs
    are then inspected.
    """
    # torch.set_default_dtype is process-global: restore it afterwards so later
    # tests in the same process do not inherit it. The CUDA default device is
    # scoped with a context manager for the same reason.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device("cuda"):
            _check_fix_functionalization(model_class, do_fusion, dtype)
    finally:
        torch.set_default_dtype(prev_dtype)


def _check_fix_functionalization(
    model_class: type[torch.nn.Module], do_fusion: bool, dtype: torch.dtype
) -> None:
    """Body of test_fix_functionalization; the caller sets device and dtype."""
    torch.manual_seed(0)

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            custom_ops=["all"],
            pass_config=PassConfig(
                fuse_norm_quant=do_fusion,
                fuse_act_quant=do_fusion,
                eliminate_noops=True,
            ),
        ),
    )

    with set_current_vllm_config(vllm_config):
        assert RMSNorm.enabled()
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

        passes = (
            [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
            if do_fusion
            else [noop_pass, cleanup_pass]
        )
        func_pass = FixFunctionalizationPass(vllm_config)

        backend_func = TestBackend(*passes, func_pass)
        backend_no_func = TestBackend(*passes)

        # Build the model once. Deep copies give both compiled variants
        # identical weights and independent inputs (some ops mutate inputs).
        model = model_class()
        inputs_func = model.example_inputs()
        inputs_no_func = copy.deepcopy(inputs_func)
        model_no_func = copy.deepcopy(model)
        model_func = torch.compile(model, backend=backend_func)
        model_no_func = torch.compile(model_no_func, backend=backend_no_func)
        model_func(*inputs_func)
        model_no_func(*inputs_no_func)

        expected_ops = model.ops_in_model(do_fusion)
        forbidden_ops = model.ops_not_in_model()

        # Check that the functionalization pass is applied. Without it, each
        # expected op is still wrapped in auto_functionalized (find_auto_fn is
        # used as an existence check); with it, the wrapper must be gone.
        for op in expected_ops:
            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None

        # Make sure the ops were all de-functionalized, i.e. called directly in
        # the fixed graph, and that none of the forbidden ops remain.
        nodes = list(backend_func.graph_post_pass.nodes)

        def _in_graph(op) -> bool:
            return any(is_func(node, op) for node in nodes)

        missing = [op for op in expected_ops if not _in_graph(op)]
        unexpected = [op for op in forbidden_ops if _in_graph(op)]
        assert not missing, f"expected ops missing from fixed graph: {missing}"
        assert not unexpected, f"unexpected ops in fixed graph: {unexpected}"

        # TODO (Rohan138): compare the outputs from model_func and model_no_func
        # currently runs into errors while comparing `TestFusedAddRMSNorm`
        # Linked issue: https://github.com/vllm-project/vllm/issues/34996
        # torch.testing.assert_close(outputs_func, outputs_no_func)