test_mrope.py 7.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

# Prefer the CUDA device whenever torch.cuda reports one is available
# (this is also true on ROCm builds of PyTorch); otherwise use the CPU.
device = (torch.device("cuda")
          if torch.cuda.is_available() else torch.device("cpu"))


def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                       head_size: int, max_position_embeddings: int,
                       dtype: torch.dtype, device: torch.device):
    """Build random mrope inputs for one test configuration.

    Returns ``(positions, query, key)``:

    * ``positions`` has shape ``(3, num_tokens)``, with integers drawn from
      ``[0, max_position_embeddings // 4)``.
    * ``query`` has shape ``(num_tokens, num_q_heads * head_size)``.
    * ``key`` has shape ``(num_tokens, num_kv_heads * head_size)``.

    ``query`` and ``key`` are standard-normal samples of ``dtype``. All
    three tensors are allocated on ``device``.
    """
    position_upper_bound = max_position_embeddings // 4
    positions = torch.randint(0, position_upper_bound, (3, num_tokens),
                              device=device)

    def _random_heads(num_heads: int) -> torch.Tensor:
        # Flattened per-token activations for ``num_heads`` heads.
        return torch.randn(num_tokens,
                           num_heads * head_size,
                           dtype=dtype,
                           device=device)

    # Draw order (positions, then query, then key) is kept fixed so that a
    # seeded RNG reproduces the same tensors.
    query = _random_heads(num_q_heads)
    key = _random_heads(num_kv_heads)
    return positions, query, key


def unroll_model_tp_dict(model_tp_dict):
    """Flatten ``{model_name: [tp_size, ...]}`` into a list of pairs.

    The result is ``[(model_name, tp_size), ...]``, in the dict's iteration
    order and, within each model, in the order of its tp sizes. It is
    suitable for a two-argument ``pytest.mark.parametrize``.
    """
    pairs = []
    for name, sizes in model_tp_dict.items():
        pairs.extend((name, size) for size in sizes)
    return pairs


# Hugging Face checkpoints whose configs drive the mrope tests, each mapped
# to the tensor-parallel sizes to simulate. unroll_model_tp_dict() turns this
# into (model_name, tp_size) parametrize cases.
model_tp_dict = {
    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
}

# [dtype, atol, rtol] entries. The bfloat16 tolerances match the defaults
# that torch.testing.assert_close applies to that dtype (see the "default
# tolerances" table in the torch.testing.assert_close documentation).
dtype_atol_rtol_list = [
    [torch.bfloat16, 1e-5, 1.6e-2],
]

# One small, odd-sized batch and one large batch of tokens.
num_tokens_list = [11, 8192]


@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_name, tp_size",
                         unroll_model_tp_dict(model_tp_dict))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
    """Check that ``forward_cuda`` agrees with ``forward_native``.

    The test builds the rotary-embedding module from the model's Hugging Face
    config via ``get_rope``. Head counts are divided by ``tp_size`` to mimic
    one tensor-parallel rank. Both code paths are fed identical random
    inputs, and their outputs are compared within ``atol`` / ``rtol``.
    """
    cfg = AutoConfig.from_pretrained(model_name)

    # Per-rank shapes derived from the model config.
    head_size = cfg.hidden_size // cfg.num_attention_heads
    q_heads_per_rank = cfg.num_attention_heads // tp_size
    kv_heads_per_rank = max(1, cfg.num_key_value_heads // tp_size)
    max_pos = cfg.max_position_embeddings

    rope = get_rope(
        head_size=head_size,
        rotary_dim=head_size,
        max_position=max_pos,
        base=cfg.rope_theta,
        is_neox_style=True,
        rope_scaling=cfg.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    positions, query, key = generate_test_data(num_tokens, q_heads_per_rank,
                                               kv_heads_per_rank, head_size,
                                               max_pos, dtype, device)

    # Each path receives its own copies so neither can observe the other's
    # (possibly in-place) writes.
    q_ref, k_ref = rope.forward_native(positions, query.clone(), key.clone())
    q_out, k_out = rope.forward_cuda(positions, query.clone(), key.clone())

    for expected, actual in ((q_ref, q_out), (k_ref, k_out)):
        torch.testing.assert_close(expected, actual, atol=atol, rtol=rtol)


@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
    "model_name, tp_size",
    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", [4])
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
                                     num_tokens):
    """Check that ``forward_cuda`` compiles under torch.compile and stays correct.

    The function is compiled with ``fullgraph=True``, the inductor backend
    and ``mode="reduce-overhead"``. The compiled output is compared against
    both the eager ``forward_cuda`` output and the ``forward_native``
    reference.

    Only compilation and the first (tracing) call are wrapped in
    ``pytest.fail``. A numerical mismatch therefore surfaces as an ordinary
    assertion failure with its own traceback, not as a tracing failure.
    """
    config = AutoConfig.from_pretrained(model_name)

    # Per-rank shapes derived from the model config.
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    head_dim = config.hidden_size // total_num_heads
    is_neox_style = True
    rope_theta = config.rope_theta
    max_position = config.max_position_embeddings

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=config.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    positions, query, key = generate_test_data(num_tokens, num_heads,
                                               num_kv_heads, head_dim,
                                               max_position, dtype, device)

    def functional_forward_cuda(pos, q, k):
        """Run ``forward_cuda`` on private copies of ``q`` and ``k``.

        ``forward_cuda`` may write its query/key arguments in place. Cloning
        first means the compiled function never mutates its own inputs, and
        the caller's tensors are left untouched. Input mutation can prevent
        CUDA-graph capture under ``mode="reduce-overhead"``.
        """
        return mrope_helper_class.forward_cuda(pos, q.clone(), k.clone())

    # Reference result from the pure-PyTorch implementation.
    query_native, key_native = mrope_helper_class.forward_native(
        positions,
        query.clone(),
        key.clone(),
    )

    # torch.compile is lazy: tracing and compilation happen on the first
    # call, so both the compile and that call belong inside the try.
    try:
        compiled_forward_cuda = torch.compile(functional_forward_cuda,
                                              fullgraph=True,
                                              backend="inductor",
                                              mode="reduce-overhead",
                                              dynamic=False)
        query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
            positions,
            query,
            key,
        )
    except Exception as e:
        pytest.fail(
            f"forward_cuda failed to trace with torch.compile inductor: {e}")

    # Eager forward_cuda on fresh copies. Use the returned tensors rather
    # than relying on in-place mutation of the arguments.
    query_cuda, key_cuda = mrope_helper_class.forward_cuda(
        positions,
        query.clone(),
        key.clone(),
    )

    # Compiled output must match both the eager kernel and the reference.
    for actual, expected in ((query_compiled_cuda, query_cuda),
                             (key_compiled_cuda, key_cuda),
                             (query_compiled_cuda, query_native),
                             (key_compiled_cuda, key_native)):
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

    print("✓ forward_cuda successfully traced with torch.compile inductor")