test_pos_encoding.py
from typing import Optional, Tuple

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm import pos_encoding_ops

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
SEEDS = [0]


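# "Rotate half" helpers for the two RoPE layouts: NeoX style splits the last
# dimension into two contiguous halves, while GPT-J style pairs interleaved
# even/odd elements.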
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


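# Reference RoPE application: x_embed = x * cos + rotate(x) * sin, applied to
# both query and key with the layout-specific rotate helper.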
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
    q_embed = (q * cos) + (rotate_fn(q) * sin)
    k_embed = (k * cos) + (rotate_fn(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbedding(nn.Module):
    """Reference implementation of rotary embedding."""

    def __init__(
        self,
        dim: int,
        is_neox_style: bool,
        max_position_embeddings: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.rotary_dim = dim
        self.is_neox_style = is_neox_style
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        if is_neox_style:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.repeat_interleave(freqs, 2, -1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

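    # forward() rotates only the first `rotary_dim` channels of query/key by
    # each position's cos/sin and passes the remaining channels through
    # unchanged.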
    def forward(
        self,
        positions: torch.Tensor,  # [num_tokens]
        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]

        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)

        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
                                        self.is_neox_style)
        query_rot = query_rot.transpose(0, 1).contiguous()
        key_rot = key_rot.transpose(0, 1).contiguous()

        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)

        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key


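# Exercise pos_encoding_ops.rotary_embedding across both layouts and a range
# of dtypes, head counts, head sizes, rotary dims, and token counts, and
# compare it against the PyTorch reference above.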
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
    query = torch.randn(num_tokens,
                        num_heads * head_size,
                        dtype=dtype,
                        device="cuda")
    key = torch.randn(num_tokens,
                      num_heads * head_size,
                      dtype=dtype,
                      device="cuda")
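    # query/key are flattened to [num_tokens, num_heads * head_size], the
    # layout the kernel operates on.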

    # Create the rotary embedding.
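    # cos_sin_cache has shape [max_position, rotary_dim]: the first
    # rotary_dim // 2 columns hold cos, the rest hold sin.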
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
    out_key = key.clone()
    pos_encoding_ops.rotary_embedding(
        positions,
        out_query,
        out_key,
        head_size,
        cos_sin_cache,
        is_neox_style,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbedding(
        dim=rotary_dim,
        is_neox_style=is_neox_style,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device="cuda")
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
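    # Flatten the reference output back to the kernel's
    # [num_tokens, num_heads * head_size] layout before comparing.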
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
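

# Minimal manual entry point: a sketch for running one configuration directly
# (assumes a CUDA device and the compiled vllm extension; pytest is the normal
# way to run this file, e.g. `pytest test_pos_encoding.py`).
if __name__ == "__main__":
    test_rotary_embedding(
        is_neox_style=True,
        num_tokens=11,
        num_heads=7,
        head_size=64,
        rotary_dim=None,
        dtype=torch.float,
        seed=0,
    )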