"""Tests for the rotary positional embedding (RoPE) modules built by get_rope."""
from itertools import accumulate, product
from typing import Dict, List, Optional

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
# Run on a single device when exactly one GPU is visible, otherwise on
# cuda:0 and cuda:1.
# NOTE(review): with zero visible GPUs this still yields two CUDA devices,
# so the parametrized tests fail rather than skip on CPU-only machines.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Compare ``rope.forward`` against the ``rope.forward_native`` reference.

    A rope module without rope scaling is built with ``get_rope``, random
    positions/queries/keys are pushed through both code paths, and the
    outputs must agree within the tolerances returned by
    ``get_default_atol`` / ``get_default_rtol`` for the output tensors.
    ``rotary_dim=None`` means the rotary dim equals ``head_size``.
    """
    if rotary_dim is None:
        rotary_dim = head_size

    seed_everything(seed)
    # NOTE(review): set_default_device changes process-global state and is
    # not restored when this test finishes.
    torch.set_default_device(device)
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check the offsets-taking ``rope.forward`` path with all-zero offsets.

    The rope is built with a single linear scaling factor of 1. The
    reference ``forward_native`` runs without offsets while ``forward`` is
    given a zero offset for every token; both results must agree within the
    tolerances returned by ``get_default_atol`` / ``get_default_rtol``.
    """
    seed_everything(seed)
    # NOTE(review): set_default_device changes process-global state and is
    # not restored when this test finishes.
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": (1, )
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    # One (zero) offset per token, flattened over batch and sequence.
    zero_offsets = torch.zeros(batch_size * seq_len,
                               dtype=torch.long,
                               device=device)
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
                                      offsets=zero_offsets)
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check per-token offsets into a multi-factor linear-scaling rope.

    The rope is built with linear scaling factors 1, 2 and 4. Every token is
    randomly assigned one of those factors and given the matching entry of
    ``offset_map`` as its offset. ``rope.forward_native`` receives the
    offsets shaped ``(batch_size, seq_len)``; ``rope.forward`` receives them
    flattened. Both results must agree within the tolerances returned by
    ``get_default_atol`` / ``get_default_rtol``.
    """
    seed_everything(seed)
    # NOTE(review): set_default_device changes process-global state and is
    # not restored when this test finishes.
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    scaling_factors: List[int] = [1, 2, 4]
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # offset_map[i] is the offset used for scaling_factors[i]: 0 for the
    # first factor, then running sums of 2 * max_position * factor over the
    # preceding factors.
    # NOTE(review): confirm this spacing matches the cos/sin cache layout of
    # the linear-scaling rope module.
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    # Randomly pick a scaling factor (by index) for every token.
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    query_offsets = offset_map[query_types]

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key,
                                             query_offsets)
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@torch.inference_mode()
def test_rope_module_cache():
    """Check that ``get_rope`` caches one module per distinct configuration.

    Every combination of the settings below is requested twice. On the
    first pass, each distinct configuration must yield a module not seen
    before, whose buffers and parameters all have the requested dtype. On
    the second pass, the identical object must be returned again.
    """
    MAX_POSITIONS = [123, 1234]
    BASES = [10000, 1000000]
    ROPE_SCALINGS = (None, {
        "type": "linear",
        "factor": (1, )
    }, {
        "type": "dynamic",
        "factor": 1
    })
    settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
                ROPE_SCALINGS, DTYPES)

    def build_rope(setting):
        """Resolve ``setting`` into get_rope arguments and fetch the module.

        Returns ``(key, rope)`` where ``key`` is a string form of the
        resolved arguments (``str`` because rope_scaling may be an
        unhashable dict) and ``rope`` is the module returned by get_rope.
        """
        (head_size, rotary_dim, max_position, base, is_neox_style,
         rope_scaling, dtype) = setting
        # None means "rotary dim == head size". Resolve it before building
        # the key so two settings that resolve to the same get_rope
        # arguments count as the same configuration.
        if rotary_dim is None:
            rotary_dim = head_size
        args = (head_size, rotary_dim, max_position, base, is_neox_style,
                rope_scaling, dtype)
        return str(args), get_rope(*args)

    rope_setting_id_map: Dict[str, int] = {}
    seen_rope_ids = set()
    for setting in product(*settings):
        key, rope = build_rope(setting)
        if key in rope_setting_id_map:
            # Same resolved arguments as an earlier setting; the identity
            # check in the second loop covers this case.
            continue
        # different settings cannot share the same rope module
        assert id(rope) not in seen_rope_ids
        dtype = setting[-1]
        assert all(x.dtype == dtype for x in rope.buffers())
        assert all(x.dtype == dtype for x in rope.parameters())
        seen_rope_ids.add(id(rope))
        rope_setting_id_map[key] = id(rope)

    for setting in product(*settings):
        key, rope = build_rope(setting)
        # check that the cache takes effect
        assert id(rope) == rope_setting_id_map[key]