# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional
import math

import torch

from einops import rearrange, repeat

import rotary_emb


def rotate_half(x, interleaved=False):
    """Rotate pairs of dimensions of x by 90 degrees: (x1, x2) -> (-x2, x1).

    If interleaved is False (GPT-NeoX style), x1 and x2 are the 1st and 2nd halves of the last
    dimension; if True (GPT-J style), they are the even and odd entries of the last dimension.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, 's d -> s 1 (2 d)')
    sin = repeat(sin, 's d -> s 1 (2 d)')
    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
                      x[..., ro_dim:]], dim=-1)
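
# Illustrative sketch (not part of the original file): building the cos/sin tables that
# apply_rotary_emb_torch expects and applying them to a random tensor. The shapes and the
# base of 10000.0 are assumptions chosen for this example, matching the conventions used below.
#
#     batch, seqlen, nheads, headdim = 2, 128, 12, 64
#     rotary_dim = headdim  # any even value <= headdim works; the rest passes through unchanged
#     inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
#     t = torch.arange(seqlen, dtype=torch.float32)
#     freqs = torch.outer(t, inv_freq)                       # (seqlen, rotary_dim / 2)
#     cos, sin = torch.cos(freqs), torch.sin(freqs)
#     x = torch.randn(batch, seqlen, nheads, headdim)
#     out = apply_rotary_emb_torch(x, cos, sin)              # same shape as x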


class ApplyRotaryEmb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
            x: (batch_size, seqlen, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
                      else (out_ro[..., ::2], out_ro[..., 1::2]))
        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (do_ro[..., ::2], do_ro[..., 1::2]))
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
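
# Usage sketch (an illustrative assumption, not part of the original file): the fused path
# requires the compiled `rotary_emb` CUDA extension and CUDA tensors. Random cos/sin values
# are used here only to show the expected shapes.
#
#     x = torch.randn(2, 128, 12, 64, device='cuda', dtype=torch.float16, requires_grad=True)
#     cos = torch.randn(128, 32, device='cuda', dtype=torch.float16)  # (seqlen, rotary_dim / 2)
#     sin = torch.randn(128, 32, device='cuda', dtype=torch.float16)
#     out = apply_rotary_emb_func(x, cos, sin, False, False)  # interleaved=False, inplace=False
#     out.sum().backward()                                     # gradients flow through the kernel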


class ApplyRotaryEmbQKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
            qkv: (batch_size, seqlen, 3, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
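
# Usage sketch (an illustrative assumption, not part of the original file): rotary applied
# *inplace* to q and k of a packed qkv tensor, as produced by a fused QKV projection.
#
#     qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.float16)
#     cos = torch.randn(128, 32, device='cuda', dtype=torch.float16)  # (seqlen, rotary_dim / 2)
#     sin = torch.randn(128, 32, device='cuda', dtype=torch.float16)
#     qkv = apply_rotary_emb_qkv_(qkv, cos, sin)  # q and k rotated in place, v left untouched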


class ApplyRotaryEmbKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
            kv: (batch_size, seqlen, 2, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
                                False)  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
                                True)  # conj=True since this is the backward pass
        return dkv, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
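
# Usage sketch (an illustrative assumption, not part of the original file): rotary applied
# *inplace* to the k half of a packed kv tensor, for the case where q and kv come from separate
# projections; q would then be rotated separately with apply_rotary_emb_func.
#
#     kv = torch.randn(2, 128, 2, 12, 64, device='cuda', dtype=torch.float16)
#     cos = torch.randn(128, 32, device='cuda', dtype=torch.float16)  # (seqlen, rotary_dim / 2)
#     sin = torch.randn(128, 32, device='cuda', dtype=torch.float16)
#     kv = apply_rotary_emb_kv_(kv, cos, sin)     # k rotated in place, v left untouched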


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
                 pos_idx_in_fp32=True, device=None):
        """
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
                otherwise they might be in lower precision.
                This option was added because previously (before 2023-07-02), when we construct
                the position indices, we use the dtype of self.inv_freq. In most cases this would
                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
                self.inv_freq would be bf16, and the position indices are also in bf16.
                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
                embeddings for some positions will coincide.
                To maintain compatibility with models previously trained in pure bf16,
                we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
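        # When scale_base is not None (XPos), the per-pair decay factors computed below are
        # zeta_i = (i + 0.4 * dim) / (1.4 * dim) for i = 0, 2, ..., dim - 2, following the
        # torchscale reference in the class docstring. In _update_cos_sin_cache the query tables
        # are multiplied by zeta ** power and the key tables by zeta ** (-power).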
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))


    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
             else just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    None, None, self.interleaved
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self.interleaved
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
            return q, kv
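

# A minimal smoke-test sketch (not part of the original file). It assumes a CUDA device and the
# compiled `rotary_emb` extension are available; the shapes and dtype are arbitrary choices
# made for this example.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 2, 128, 12, 64
    rotary = RotaryEmbedding(headdim, interleaved=False).to("cuda")
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
    q_before = qkv[:, :, 0].clone()  # the fused kernels rotate q and k in place
    qkv = rotary(qkv)
    # Compare the fused result for q against the pure-PyTorch reference implementation above.
    q_ref = apply_rotary_emb_torch(q_before.float(), rotary._cos_cached.float(),
                                   rotary._sin_cached.float())
    print("max |q_fused - q_ref| =", (qkv[:, :, 0].float() - q_ref).abs().max().item())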