pos_encoding.py 4.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from cacheflow import pos_encoding_ops


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the
    negated upper part followed by the lower part, concatenated along the
    last dimension.
    """
    split_at = x.shape[-1] // 2
    lower = x[..., :split_at]
    upper = x[..., split_at:]
    return torch.cat((-upper, lower), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary position embedding to a query/key pair.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; ``cos``
    and ``sin`` must be broadcastable against ``q`` and ``k``.

    Returns:
        The ``(query, key)`` tensors after rotation.
    """

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Same formula for query and key: t * cos + rotate_half(t) * sin.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding.

    The constructor precomputes cos/sin tables of shape
    ``[max_position_embeddings, dim]`` and stores them as non-persistent
    buffers (``cos_cached`` / ``sin_cached``); ``forward`` looks up the rows
    for the given positions and rotates the query and key with them.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        """Build the cos/sin lookup tables.

        Args:
            dim: Size of the rotated (last) dimension.
            max_position_embeddings: Number of positions in the tables.
            base: Base of the geometric progression of inverse frequencies.
        """
        super().__init__()
        self.max_position_embeddings = max_position_embeddings

        # Inverse frequencies: base ** (-2i / dim) for i in [0, dim / 2).
        exponents = torch.arange(0, dim, 2) / dim
        inv_freq = 1.0 / (base ** exponents)

        # angles[p, i] = p * inv_freq[i]; duplicated along the last dim so the
        # table width matches the full `dim`.
        position_ids = torch.arange(max_position_embeddings).float()
        angles = torch.outer(position_ids, inv_freq.float())
        table = torch.cat((angles, angles), dim=-1)

        self.register_buffer(
            "cos_cached",
            table.cos().to(dtype=inv_freq.dtype),
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            table.sin().to(dtype=inv_freq.dtype),
            persistent=False,
        )

    def forward(
        self,
        positions: torch.LongTensor,    # [num_tokens]
        query: torch.Tensor,            # [num_tokens, num_heads, head_size]
        key: torch.Tensor,              # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate `query` and `key` according to `positions`.

        Returns:
            Contiguous `(query, key)` tensors, each of shape
            [num_tokens, num_heads, head_size].
        """
        # Per-token rows of the tables: [num_tokens, head_size].
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)

        # Move heads to the front ([num_heads, num_tokens, head_size]) so the
        # per-token cos/sin rows broadcast over the head dimension.
        rotated_query, rotated_key = apply_rotary_pos_emb(
            query.transpose(0, 1),
            key.transpose(0, 1),
            cos,
            sin,
        )

        # Restore the [num_tokens, num_heads, head_size] layout.
        return (
            rotated_query.transpose(0, 1).contiguous(),
            rotated_key.transpose(0, 1).contiguous(),
        )


@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    """Check the `rotary_embedding_neox` kernel against the reference module.

    Random positions, queries and keys are created on the CUDA device, the
    in-place kernel is run on clones of the inputs, and the results are
    compared with `RefRotaryEmbeddingNeox` using `torch.allclose`.

    Args:
        num_tokens: Number of tokens (rows) in the query/key tensors.
        num_heads: Number of attention heads.
        head_size: Size of each head; also the rotary dimension.
        max_position: Exclusive upper bound on the sampled positions and the
            number of rows in the cos/sin cache.
        dtype: Data type of the query, key and cos/sin cache.
        base: Base used to compute the inverse frequencies.

    Raises:
        AssertionError: If the kernel output differs from the reference
            beyond the tolerances (atol=1e-3, rtol=1e-5).
    """
    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')

    # Create the rotary embedding.
    # The cache row for each position is [cos | sin], each half having
    # head_size // 2 entries.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
    out_key = key.clone()
    pos_encoding_ops.rotary_embedding_neox(
        positions,
        out_query,
        out_key,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=head_size,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    # Flatten the heads back so the reference matches the kernel's 2-D layout.
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    # Sweep every (dtype, head_size) combination; dtype is the outer axis, so
    # all head sizes run for torch.half before any run for torch.float.
    configs = [
        (dtype, head_size)
        for dtype in [torch.half, torch.float]
        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]
    ]
    for dtype, head_size in configs:
        print(f'Running tests for head_size={head_size} and dtype={dtype}')
        test_rotary_embedding_neox(
            num_tokens=2145,
            num_heads=5,
            head_size=head_size,
            max_position=8192,
            dtype=dtype,
        )