# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import rotary_emb
import torch
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
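
# For example (illustrative): with [x0, x1, x2, x3] along the last dimension,
# rotate_half(..., interleaved=False) returns [-x2, -x3, x0, x1] (rotate the two halves,
# GPT-NeoX style), while rotate_half(..., interleaved=True) returns [-x1, x0, -x3, x2]
# (rotate adjacent even/odd pairs, GPT-J style).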


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # For the interleaved (GPT-J) layout each cos/sin value must cover an adjacent
    # (even, odd) pair, hence the "(d 2)" grouping instead of "(2 d)".
    cos = repeat(cos, "s d -> s 1 (2 d)" if not interleaved else "s d -> s 1 (d 2)")
    sin = repeat(sin, "s d -> s 1 (2 d)" if not interleaved else "s d -> s 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
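
# A minimal sketch of driving the pure-PyTorch reference above (illustrative only; the
# frequency recipe mirrors RotaryEmbedding._compute_inv_freq further down):
#
#     batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 64
#     x = torch.randn(batch, seqlen, nheads, headdim)
#     inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
#     freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
#     out = apply_rotary_emb_torch(x, freqs.cos(), freqs.sin())  # same shape as x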


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
            x: (batch_size, seqlen, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
41
42
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )  # conj=False since this is the forward pass
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            do1,
            do2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dx1,
            dx2,
            True,
        )  # conj=True since this is the backward pass
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
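
# A minimal usage sketch for the fused path (illustrative only; assumes a CUDA tensor and
# the compiled rotary_emb extension, with cos/sin built as in the sketch above and cast to
# x's dtype, as RotaryEmbedding below does):
#
#     x = torch.randn(2, 16, 4, 64, device="cuda", dtype=torch.float16)
#     # cos, sin: (seqlen, rotary_dim / 2)
#     out = apply_rotary_emb_func(x, cos, sin, False, False)  # interleaved=False, inplace=False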


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
            qkv: (batch_size, seqlen, 3, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
Tri Dao's avatar
Tri Dao committed
124
            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
125
126
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(
            q1,
            q2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            q1,
            q2,
            False,
        )
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (
            dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dq1,
            dq2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dq1,
            dq2,
            True,
        )
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
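
# A minimal in-place usage sketch on a packed QKV tensor (illustrative only; assumes CUDA
# and the compiled rotary_emb extension):
#
#     qkv = torch.randn(2, 16, 3, 4, 64, device="cuda", dtype=torch.float16)
#     # cos, sin: (seqlen, rotary_dim / 2); q and k are rotated in place, v is untouched
#     qkv = apply_rotary_emb_qkv_(qkv, cos, sin, None, None, False)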


class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )  # conj=True since this is the backward pass
        return dkv, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
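
# Usage mirrors apply_rotary_emb_qkv_ but on a packed (batch, seqlen, 2, nheads, headdim)
# KV tensor (illustrative only; assumes CUDA and the compiled rotary_emb extension):
#
#     kv = torch.randn(2, 16, 2, 4, 64, device="cuda", dtype=torch.float16)
#     kv = apply_rotary_emb_kv_(kv, cos, sin, False)  # rotates k (index 0) in place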


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we constructed
            the position indices, we used the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices would also be in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        # inv_freq[i] = base ** (-2 * i / dim): low dimensions rotate fastest, high ones slowest.
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16,
            # and the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
             else it's just q of shape (batch, seqlen, nheads, headdim).
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    None,
                    None,
                    self.interleaved,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
            return q, kv
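

if __name__ == "__main__":
    # A minimal end-to-end smoke test (illustrative only; requires a CUDA device and the
    # compiled rotary_emb extension, since the fused kernels operate on CUDA tensors).
    if torch.cuda.is_available():
        rotary = RotaryEmbedding(dim=64, interleaved=False).to("cuda")
        qkv = torch.randn(2, 128, 3, 16, 64, device="cuda", dtype=torch.float16)
        qkv = rotary(qkv)  # rotates q and k of the packed qkv in place
        # Decoding step: rotate the new token's q and packed kv, offset by the cache length.
        q = torch.randn(2, 1, 16, 64, device="cuda", dtype=torch.float16)
        kv = torch.randn(2, 1, 2, 16, 64, device="cuda", dtype=torch.float16)
        q, kv = rotary(q, kv=kv, seqlen_offset=128)
        print(qkv.shape, q.shape, kv.shape)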