# test_pos_encoding.py
from typing import Optional, Tuple
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm import pos_encoding_ops

# Parameter grids swept by the pytest.mark.parametrize decorators on
# test_rotary_embedding_neox; every combination becomes one test case.
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
SEEDS = [0]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2`` into ``x1`` (front)
    and ``x2`` (back); the result is ``cat(-x2, x1)`` along that dimension,
    so the output has the same shape as ``x``.

    Args:
        x: Tensor of any rank with at least one dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary transform to a query/key pair.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; the
    same ``cos`` and ``sin`` tables are used for both tensors.

    Returns:
        The ``(query, key)`` pair after the transform.
    """

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # Same arithmetic for q and k; only the operand differs.
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding.

    The cos/sin tables for every position in ``[0, max_position_embeddings)``
    are precomputed at construction time and stored as non-persistent
    buffers (``cos_cached`` / ``sin_cached``), each of shape
    ``[max_position_embeddings, dim]``.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        """Build the cos/sin lookup tables.

        Args:
            dim: Number of leading features of each head that are rotated.
            max_position_embeddings: Number of positions to precompute.
            base: Base of the geometric progression of inverse frequencies.
        """
        super().__init__()
        self.rotary_dim = dim
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        # Outer product: freqs[p, i] = position p * inverse frequency i.
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        # Duplicate the frequencies so the table width matches the rotated
        # slice that rotate_half splits in two.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,  # [num_tokens]
        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate the first ``rotary_dim`` features of ``query`` and ``key``.

        Features past ``rotary_dim`` are passed through unchanged.

        Returns:
            The ``(query, key)`` pair with the same shapes as the inputs.
        """
        # Split each head into the rotated slice and the pass-through rest.
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]

        # Move the token axis next to the feature axis so the per-token
        # cos/sin rows broadcast over the heads.
        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
        query_rot = query_rot.transpose(0, 1).contiguous()
        key_rot = key_rot.transpose(0, 1).contiguous()

        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)

        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Compare the rotary_embedding_neox kernel with the reference module.

    Random positions, queries and keys are created on the CUDA device, run
    through ``pos_encoding_ops.rotary_embedding_neox`` and through
    ``RefRotaryEmbeddingNeox``, and the two results must agree within
    ``atol=1e-5`` / ``rtol=1e-5``.
    """
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Inputs are flattened to [num_tokens, num_heads * head_size].
    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
    query = torch.randn(num_tokens,
                        num_heads * head_size,
                        dtype=dtype,
                        device='cuda')
    key = torch.randn(num_tokens,
                      num_heads * head_size,
                      dtype=dtype,
                      device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    # One row per position: the cos values followed by the sin values.
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
    out_key = key.clone()
    pos_encoding_ops.rotary_embedding_neox(
        positions,
        out_query,
        out_key,
        head_size,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=rotary_dim,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    # Flatten the reference outputs back to the kernel's 2-D layout.
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)