# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py

from typing import Tuple
import math

import torch

from einops import rearrange, repeat

import rotary_emb


def rotate_half(x):
    """Rotate the last dimension by half: [x1, x2] -> [-x2, x1].

    The last dimension is split with ``chunk(2)`` (so for an odd size the first
    half is the larger one), and the negated second half is placed in front of
    the first half.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)


def apply_rotary_emb_torch(x, cos, sin):
    """Pure-PyTorch reference rotary embedding (non-interleaved, GPT-NeoX style).

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen_ro, rotary_dim / 2) with seqlen_ro >= seqlen. Only the first
        seqlen rows are used, matching the CUDA path in ApplyRotaryEmb.
    rotary_dim must be <= headdim. The first rotary_dim channels of x are rotated and
    the remaining channels are passed through unchanged.
    Returns a new tensor with the same shape as x.
    """
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1]
    seqlen = x.shape[1]
    assert seqlen <= cos.shape[0]
    # (seqlen, rotary_dim / 2) -> (seqlen, 1, rotary_dim / 2) so it broadcasts over heads.
    cos = cos[:seqlen].unsqueeze(-2)
    sin = sin[:seqlen].unsqueeze(-2)
    x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
    # Same arithmetic as x_ro * [cos, cos] + rotate_half(x_ro) * [sin, sin], without
    # materializing the repeated tables.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin, x[..., rotary_dim:]], dim=-1)


class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd function applying rotary embedding to x with the `rotary_emb` extension kernel."""

    @staticmethod
    def forward(ctx, x, cos, sin, inplace=False):
        """
            x: (batch_size, seqlen, nheads, headdim)
            cos, sin: (seqlen_ro, rotary_dim / 2) with seqlen_ro >= seqlen; only the first
                seqlen rows are used.
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x. If inplace, x itself is
        overwritten and returned; otherwise a new tensor is returned and the channels past
        rotary_dim are copied over unchanged.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
        out = torch.empty_like(x) if not inplace else x
        o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
        # The trailing flag is presumably the kernel's conjugate switch: False here, True in
        # backward (inverse rotation) -- confirm against the rotary_emb extension.
        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.inplace = inplace
        # NOTE(review): the in-place path modifies an input without ctx.mark_dirty(x);
        # confirm callers never pass a tensor whose old value autograd still needs.
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        """Rotate the incoming gradient back (the inverse rotation) over the first rotary_dim.

        Channels past rotary_dim pass through unchanged. With inplace, `do` itself is
        overwritten. Returns gradients for (x, cos, sin, inplace); only x gets one.
        """
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
        dx = torch.empty_like(do) if not inplace else do
        dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None


# Functional entry point: apply_rotary_emb_func(x, cos, sin[, inplace]) -> rotated x.
# It is autograd.Function.apply, so pass arguments positionally.
apply_rotary_emb_func = ApplyRotaryEmb.apply


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """Autograd function rotating q and k of a packed qkv tensor *in place* (v is untouched)."""

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
        """
            qkv: (batch_size, seqlen, 3, nheads, headdim)
            cos, sin: (seqlen_ro, rotary_dim / 2), seqlen_ro >= seqlen; used for q
            cos_k, sin_k: (seqlen_ro, rotary_dim / 2), optional; used for k, default to cos, sin
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        Returns qkv.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        # q = qkv[:, :, 0]: input and output views are the same, so the kernel writes in place.
        q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
        # k = qkv[:, :, 1], rotated with its own (possibly distinct) tables.
        k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        # NOTE(review): qkv is modified in place without ctx.mark_dirty(qkv); confirm no
        # earlier op saved qkv for its own backward.
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        """Apply the inverse rotation to the q and k gradients, in place in dqkv.

        The v gradient and channels past rotary_dim pass through unchanged.
        Returns gradients for (qkv, cos, sin, cos_k, sin_k); only qkv gets one.
        """
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
        dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
        return dqkv, None, None, None, None


# Functional entry point: apply_rotary_emb_qkv_(qkv, cos, sin[, cos_k, sin_k]) rotates q and k
# of qkv in place and returns it. It is autograd.Function.apply, so pass arguments positionally.
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        """
        dim: rotary dimension. inv_freq (and each row of the cos/sin tables) has one entry
            per index in range(0, dim, 2).
        base: base of the frequency progression, inv_freq[i] = 1 / base ** (2 * i / dim).
        scale_base: if > 0, enable XPos with this scale base; 0 disables it.
        device: device on which the initial buffers are created.
        """
        super().__init__()
        # Kept so the fp32 tables can be recomputed if the buffers get cast to fp16/bf16.
        self.dim = dim
        self.base = base
        self.scale_base = scale_base
        # Generate and save the inverse frequency buffer (non trainable)
        self.register_buffer("inv_freq", self._compute_inv_freq(device=device))
        scale = self._compute_scale(device=device) if scale_base > 0 else None
        self.register_buffer("scale", scale)

        # cos/sin tables, (re)built lazily by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # Separate key tables, only populated when XPos is enabled.
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies 1 / base ** (arange(0, dim, 2) / dim)."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))

    def _compute_scale(self, device=None):
        """Return the fp32 XPos per-channel scale (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)."""
        return ((torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + 0.4 * self.dim)
                / (1.4 * self.dim))

    def _update_cos_sin_cache(self, x, seqlen_offset=0):
        """Make sure the cached tables cover positions [0, x.shape[1] + seqlen_offset).

        x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
        The tables are stored in x.dtype on x.device.
        """
        seqlen = x.shape[1] + seqlen_offset
        # Reset the tables if the sequence length has grown, if nothing is cached yet,
        # or if we're on a new device / dtype (possibly due to tracing for instance)
        if (self._cos_cached is None or seqlen > self._seq_len_cached
                or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype):
            self._seq_len_cached = seqlen
            # Positions and frequencies must not be computed in fp16/bf16 even if the module
            # was cast to it: bf16 holds integers exactly only up to 256, so arange(seqlen)
            # and t * inv_freq would lose a lot of precision. Recompute in fp32 instead.
            if self.inv_freq.dtype in (torch.float16, torch.bfloat16):
                inv_freq = self._compute_inv_freq(device=x.device)
            else:
                inv_freq = self.inv_freq.to(device=x.device)
            t = torch.arange(seqlen, device=x.device, dtype=inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                if self.scale.dtype in (torch.float16, torch.bfloat16):
                    scale_vec = self._compute_scale(device=x.device)
                else:
                    scale_vec = self.scale.to(device=x.device)
                # Exponent is centered on the middle of the cached range (0 at seqlen // 2).
                power = ((torch.arange(seqlen, dtype=scale_vec.dtype, device=x.device)
                          - seqlen // 2) / self.scale_base)
                scale = scale_vec ** power.unsqueeze(-1)
                # We want the multiplication by scale to happen in fp32
                # Queries use cos/sin * scale, keys use cos/sin / scale.
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """
        qkv: (batch, seqlen, 3, nheads, headdim). q and k are rotated in place.
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        Returns the rotated qkv.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset)
        if self.scale is None:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
            )
        else:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
            )