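"""Tests for vLLM's rotary positional embedding (RoPE) layer.

Compares the output of the custom kernel (`rope.forward`) against the
pure-PyTorch reference (`rope._forward`) across dtypes, head sizes,
rotary dims, and NeoX-style vs. non-NeoX-style rotation.
"""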
from typing import Optional

import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
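# Use one device when only a single GPU is visible, otherwise exercise two.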
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
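    # Also seed the CUDA RNG, since the test tensors are created on the GPU.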
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # get_rope returns a RotaryEmbedding module with a cos/sin cache
    # precomputed for positions [0, max_position). The rotary_dim default
    # was already applied at the top of the function, so no second check
    # is needed here.
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
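    # Query/key use the flattened (batch, seq_len, num_heads * head_size)
    # layout that the rotary embedding kernel expects.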
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope._forward(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    assert torch.allclose(out_query,
                          ref_query,
                          atol=get_default_atol(out_query),
                          rtol=get_default_rtol(out_query))
    assert torch.allclose(out_key,
                          ref_key,
                          atol=get_default_atol(out_key),
                          rtol=get_default_rtol(out_key))
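

if __name__ == "__main__":
    # Minimal sketch for running one configuration outside pytest, assuming
    # at least one CUDA device is available; the parametrized grid above
    # remains the authoritative way to run this test. The argument values
    # are taken from the constants at the top of the file.
    test_rotary_embedding(
        is_neox_style=True,
        batch_size=1,
        seq_len=11,
        num_heads=7,
        head_size=64,
        rotary_dim=None,
        dtype=torch.float,
        seed=0,
        device="cuda:0",
    )
    print("Single-configuration rotary embedding check passed.")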