# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from .backend import TestBackend

# Platform capabilities, queried once at import time.
# TEST_FP8 gates the FP8 branches of the test modules defined in this file;
# FP8_DTYPE is the dtype their FP8 weight tensors are cast to.
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
    """Test module: ``SiluAndMul``, optionally followed by an FP8 linear.

    When ``TEST_FP8`` is true the activation output is fed through
    ``Fp8LinearOp`` configured with static, per-tensor activation
    quantization; otherwise the activation output is returned directly.
    """

    def __init__(self, hidden_size: int = 128):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # Weight scale and activation (input) scale for the FP8 linear.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        if TEST_FP8:
            # Random square weight, cast to the platform FP8 dtype and
            # transposed.
            self.w = torch.rand(hidden_size,
                                hidden_size).to(dtype=FP8_DTYPE).t()
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Apply ``SiluAndMul`` and, on FP8 platforms, the FP8 linear."""
        y = self.silu_and_mul(x)
        if not TEST_FP8:
            return y
        # Quantize the activation with its own scale (``self.scale``);
        # ``self.wscale`` is the weight scale only. This matches how
        # TestFusedAddRMSNorm passes its input scale.
        return self.fp8_linear.apply(y,
                                     self.w,
                                     self.wscale,
                                     input_scale=self.scale)

    def example_inputs(self, num_tokens=32, hidden_size=128):
        """Return a 1-tuple holding a random ``(num_tokens, 2 * hidden_size)``
        tensor; float16 when ``TEST_FP8`` is true, float32 otherwise."""
        dtype = torch.float16 if TEST_FP8 else torch.float32
        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )

    def ops_in_model(self, do_fusion):
        """Return the custom ops the test looks for in the compiled graph.

        The fused activation+quant op is expected only when both FP8 is
        available and fusion is requested.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.silu_and_mul_quant.default]
        return [torch.ops._C.silu_and_mul.default]

    def ops_not_in_model(self):
        """Return ops that must be absent from the final graph (none)."""
        return []


class TestFusedAddRMSNorm(torch.nn.Module):
    """Test module: matmul -> RMSNorm with residual -> optional FP8 linear.

    When ``TEST_FP8`` is true the normalized output is passed through
    ``Fp8LinearOp`` with static activation quantization; otherwise the
    normalized output is returned as is.
    """

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        dtype = torch.float16 if TEST_FP8 else torch.float32

        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size), dtype=dtype))
        self.norm = RMSNorm(intermediate_size, 1e-05)
        # Replace the norm weight with an all-ones parameter of the
        # dtype chosen above.
        self.norm.weight = torch.nn.Parameter(
            torch.ones(intermediate_size, dtype=dtype))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if TEST_FP8:
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

            # Activation (input) scale, FP8 weight and weight scale.
            self.scale = torch.rand(1, dtype=torch.float32)
            self.w = torch.rand(hidden_size,
                                intermediate_size).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        """Return ``(output, residual_output)``.

        ``output`` is the FP8 linear result when ``TEST_FP8`` is true,
        otherwise the normalized tensor.
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # layer normalization
        norm_output, residual_output = self.norm(mm, residual)

        if TEST_FP8:
            # scaled_mm with static input quantization
            fp8_linear_result = self.fp8_linear.apply(
                norm_output,
                self.w,
                self.wscale,
                input_scale=self.scale.to(norm_output.device),
            )

            return fp8_linear_result, residual_output

        else:
            return norm_output, residual_output

    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
        """Return random ``(hidden_states, residual)`` tensors, each of
        shape ``(batch_size * seq_len, hidden_size)``."""
        # NOTE(review): residual has hidden_size columns, while the tensor
        # handed to self.norm in forward() has intermediate_size columns
        # (result of the gate_proj matmul) - confirm this is intended.
        dtype = torch.float16 if TEST_FP8 else torch.float32
        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                    dtype=dtype)
        residual = torch.randn((batch_size * seq_len, hidden_size),
                               dtype=dtype)
        return (hidden_states, residual)

    def ops_in_model(self, do_fusion):
        """Return the custom ops the test looks for in the compiled graph.

        The fused norm+quant op is expected only when both FP8 is
        available and fusion is requested.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        else:
            return [torch.ops._C.fused_add_rms_norm.default]

    def ops_not_in_model(self):
        """Return ops that must be absent from the final graph (none)."""
        return []


class TestRotaryEmbedding(torch.nn.Module):
    """Test module that applies a rotary embedding to a query/key pair."""

    def __init__(self,
                 head_dim=64,
                 rotary_dim=None,
                 max_position=2048,
                 base=10000):
        super().__init__()
        self.head_dim = head_dim
        # A falsy rotary_dim (the default None) falls back to head_dim.
        self.rotary_dim = rotary_dim if rotary_dim else head_dim

        rope_kwargs = dict(
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, q, k):
        """Return the ``(q, k)`` pair produced by the rotary embedding."""
        rotated_q, rotated_k = self.rotary_emb(positions, q, k)
        return rotated_q, rotated_k

    def example_inputs(self, num_tokens=32, head_dim=64):
        """Return ``(positions, q, k)``: int64 positions ``0..num_tokens-1``
        and two random float16 ``(num_tokens, head_dim)`` tensors."""
        half = torch.float16
        token_positions = torch.arange(num_tokens, dtype=torch.long)
        query = torch.randn(num_tokens, head_dim, dtype=half)
        key = torch.randn(num_tokens, head_dim, dtype=half)
        return (token_positions, query, key)

    def ops_in_model(self, do_fusion):
        """Return the custom ops the test looks for; ``do_fusion`` does not
        affect the result."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return ops that must be absent from the final graph (none)."""
        return []


class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Test module: fused QKV projection, split, rotary embedding on q/k,
    then concatenation back into a single tensor."""

    def __init__(self,
                 head_dim=64,
                 num_heads=4,
                 max_position=2048,
                 base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        # Bias-free float16 projection producing q, k and v side by side.
        self.qkv_proj = torch.nn.Linear(self.hidden_size,
                                        3 * self.hidden_size,
                                        bias=False,
                                        dtype=torch.float16)

        rope_kwargs = dict(
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, hidden_states):
        """Project to QKV, rotate q and k, and return the re-joined QKV."""
        # Intended to reproduce the graph pattern:
        # mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes
        projected = self.qkv_proj(hidden_states)
        q, k, v = torch.split(projected, [self.hidden_size] * 3, dim=-1)

        rotated_q, rotated_k = self.rotary_emb(positions, q, k)

        return torch.cat([rotated_q, rotated_k, v], dim=-1)

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        """Return ``(positions, hidden_states)``: int64 positions
        ``0..num_tokens-1`` and a random float16
        ``(num_tokens, head_dim * num_heads)`` tensor."""
        width = head_dim * num_heads
        token_positions = torch.arange(num_tokens, dtype=torch.long)
        states = torch.randn(num_tokens, width, dtype=torch.float16)
        return (token_positions, states)

    def ops_in_model(self, do_fusion):
        """Return the custom ops the test looks for; ``do_fusion`` does not
        affect the result."""
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        """Return ops that must be absent from the final graph."""
        return [torch.ops.aten.slice_scatter.default]


# Module classes the functionalization test is parametrized over.
MODELS = [
    TestSiluMul,
    TestFusedAddRMSNorm,
    TestRotaryEmbedding,
    TestRotaryEmbeddingSliceScatter,
]


@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model_class: type[torch.nn.Module],
                               do_fusion: bool):
    """Check the effect of ``FixFunctionalizationPass`` on a compiled graph.

    The model is compiled twice with ``TestBackend``: once with
    ``FixFunctionalizationPass`` appended after the other passes and once
    without it. The post-pass graphs are then compared against the ops the
    model declares via ``ops_in_model`` / ``ops_not_in_model``.

    Args:
        model_class: Module class to instantiate with default arguments.
        do_fusion: Whether fusion is enabled in the pass config and the
            fusion passes are included in the pipeline.
    """
    torch.set_default_device("cuda")

    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))

    # All passes are constructed regardless of do_fusion; only the
    # selection below differs.
    noop_pass = NoOpEliminationPass(vllm_config)
    fusion_pass = RMSNormQuantFusionPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
              if do_fusion else [noop_pass, cleanup_pass])
    func_pass = FixFunctionalizationPass(vllm_config)

    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    model = model_class()
    torch.compile(model, backend=backend_func)(*model.example_inputs())
    torch.compile(model, backend=backend_no_func)(*model.example_inputs())

    expected_ops = model.ops_in_model(do_fusion)
    forbidden_ops = model.ops_not_in_model()
    func_nodes = backend_func.graph_post_pass.nodes
    no_func_nodes = backend_no_func.graph_post_pass.nodes

    # check if the functionalization pass is applied
    for op in expected_ops:
        # Called for its lookup only; the result is not used.
        find_auto_fn(no_func_nodes, op)
        assert find_auto_fn_maybe(func_nodes, op) is None, (
            f"{op} is still auto-functionalized after the fix pass")

    # make sure the ops were all de-functionalized
    found = set()
    for node in func_nodes:
        for op in (*expected_ops, *forbidden_ops):
            if is_func(node, op):
                found.add(op)

    # Membership tests (rather than indexing a dict) so that a missing op
    # yields an assertion failure instead of a KeyError.
    missing = [op for op in expected_ops if op not in found]
    assert not missing, f"expected ops not found in graph: {missing}"
    leftover = [op for op in forbidden_ops if op in found]
    assert not leftover, f"unexpected ops found in graph: {leftover}"