# SPDX-License-Identifier: Apache-2.0

3
from itertools import accumulate, product
4
from typing import Callable, Dict, List, Optional
5

6
import pytest
7
import torch
8

9
from vllm.model_executor.layers.rotary_embedding import get_rope
10
from vllm.platforms import current_platform
11

12
13
from .allclose_default import get_default_atol, get_default_rtol

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 112, 120, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [17]  # Arbitrary values for testing
BATCH_SIZES = [5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
# Test on at most two visible GPUs. With no GPU visible this list is empty,
# so pytest reports the tests parametrized over it as skipped (empty parameter
# set) instead of failing on a non-existent "cuda:N" device.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]


def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
                           head_size: int) -> tuple[int, ...]:
    return (batch_size, seq_len, num_heads * head_size)


def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
                            head_size: int) -> tuple[int, ...]:
    return (batch_size, seq_len, num_heads, head_size)


# Query/key layouts exercised by the tests. Each callable maps
# (batch_size, seq_len, num_heads, head_size) to a tensor shape: first the
# layout with a separate head axis, then the one with heads flattened.
TENSORS_SHAPES_FN = list((_get_batch_tensor_shape, _get_flat_tensor_shape))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Compare ``rope.forward`` against ``rope.forward_native``.

    Both paths run on the same random query/key tensors. Their layout is
    produced by ``tensor_shape_fn``. The outputs must agree within the
    default atol/rtol for the output dtype.

    ``rotary_dim=None`` means the rotary dimension equals ``head_size``.
    """
    if rotary_dim is None:
        rotary_dim = head_size

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype)

    # One position per (batch, token), drawn from the cached range.
    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
    query = torch.randn(query_shape, dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
    is_neox_style: bool,
    tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check the ``offsets`` path of ``rope.forward`` against the reference.

    The test builds a ``"linear"`` rope with a single scaling factor of 1. It
    calls ``rope.forward`` with an all-zero ``offsets`` tensor (one entry per
    token) and ``rope.forward_native`` without offsets. Both outputs must
    agree within the default atol/rtol for the output dtype.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "rope_type": "linear",
        "factor": (1, )
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
    query = torch.randn(query_shape, dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    # Zero offsets for every token: expected to match the offset-free
    # reference above.
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
                                      offsets=torch.zeros(batch_size * seq_len,
                                                          dtype=torch.long,
                                                          device=device))
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check a multi-factor linear rope with per-token offsets.

    The test builds a ``"linear"`` rope with scaling factors ``(1, 2, 4)``.
    Each token gets a random entry of ``offset_map``. The map holds cumulative
    sums of ``max_position * factor * 2`` over the preceding factors;
    presumably these are the starting indices of each factor's section in the
    rope cache (confirm against ``get_rope``).

    ``forward_native`` receives the offsets in their (batch, seq) shape.
    ``forward`` receives them flattened. The outputs must agree within the
    default atol/rtol for the output dtype.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    scaling_factors: List[int] = [1, 2, 4]
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "rope_type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # offset_map[i] = sum(max_position * f * 2 for f in scaling_factors[:i])
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    # Random index into scaling_factors for every (batch, token) pair.
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    query_offsets = offset_map[query_types]

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key,
                                             query_offsets)
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@torch.inference_mode()
def test_rope_module_cache():
    """Check that ``get_rope`` caches one module per distinct configuration.

    The first pass asserts two things for every combination of settings:
    ``get_rope`` returns a module not seen for any earlier combination, and
    all of that module's buffers and parameters have the requested dtype.

    The second pass asserts that requesting the same combination again
    returns the identical object.
    """
    MAX_POSITIONS = [123, 1234]
    BASES = [10000, 1000000]
    ROPE_SCALINGS = (None, {
        "rope_type": "linear",
        "factor": (1, )
    }, {
        "rope_type": "dynamic",
        "factor": 1
    })
    settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
                ROPE_SCALINGS, DTYPES)

    def _get_rope_for(setting):
        # Unpack one product() tuple, in the order of `settings`, and build
        # (or fetch from cache) the matching rope module.
        (head_size, rotary_dim, max_position, base, is_neox_style,
         rope_scaling, dtype) = setting
        if rotary_dim is None:
            rotary_dim = head_size
        return get_rope(head_size, rotary_dim, max_position, base,
                        is_neox_style, rope_scaling, dtype)

    # `setting` contains dicts (unhashable), so its str() is used as the key.
    rope_setting_id_map: Dict[str, int] = {}
    for setting in product(*settings):
        dtype = setting[-1]
        rope = _get_rope_for(setting)
        # different settings cannot share the same rope module
        assert id(rope) not in rope_setting_id_map.values()
        assert all(x.dtype == dtype for x in rope.buffers())
        assert all(x.dtype == dtype for x in rope.parameters())
        rope_setting_id_map[str(setting)] = id(rope)

    for setting in product(*settings):
        rope = _get_rope_for(setting)
        # check that the cache takes effect
        assert id(rope) == rope_setting_id_map[str(setting)]