# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat
from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
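
# Illustrative example (a sketch, not part of the original file) of the two rotation layouts on a
# headdim-4 vector:
#     x = torch.tensor([1.0, 2.0, 3.0, 4.0])
#     rotate_half(x)                    # tensor([-3., -4.,  1.,  2.])  -- GPT-NeoX style halves
#     rotate_half(x, interleaved=True)  # tensor([-2.,  1., -4.,  3.])  -- GPT-J style pairs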


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
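
# Minimal usage sketch for this pure-PyTorch reference path (illustrative; the names below are
# assumptions, shapes follow the docstring above):
#     batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
#     x = torch.randn(batch, seqlen, nheads, headdim)
#     inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
#     freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, rotary_dim / 2)
#     out = apply_rotary_emb_torch(x, torch.cos(freqs), torch.sin(freqs))
#     # The last headdim - rotary_dim features of x are passed through unchanged.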


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        out = apply_rotary(
            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None


def apply_rotary_emb(
    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        out: (batch_size, seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
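
# Illustrative decoding-step call (a sketch; assumes CUDA tensors since the Triton kernel is used,
# and that cos / sin were built to cover position cache_len -- cache_len is a stand-in name):
#     q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
#     q = apply_rotary_emb(q, cos, sin, inplace=True, seqlen_offsets=cache_len)
#     # The single new token is rotated as position cache_len rather than position 0.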


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
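
# Illustrative use on a packed QKV tensor (a sketch; assumes cos / sin of shape
# (seqlen, rotary_dim / 2) on the same CUDA device as qkv):
#     qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16)
#     qkv = apply_rotary_emb_qkv_(qkv, cos, sin)
#     # Q and K are rotated in place; V (qkv[:, :, 2]) is left untouched.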


class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
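
# Illustrative use on a KV pair (a sketch; assumes cos / sin of shape (seqlen, rotary_dim / 2) on
# the same CUDA device as kv):
#     kv = torch.randn(batch, seqlen, 2, nheads, headdim, device="cuda", dtype=torch.float16)
#     kv = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=0)
#     # Only K (kv[:, :, 0]) is rotated in place; V (kv[:, :, 1]) is left untouched.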


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02) we constructed the
            position indices using the dtype of self.inv_freq. In most cases this is fp32,
            but if the model is trained in pure bf16 (not mixed precision), then self.inv_freq
            would be bf16, and the position indices would also be in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
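        # Standard RoPE frequencies: inv_freq[i] = base ** (-2 * i / dim)
        # for i = 0, ..., dim // 2 - 1 (one frequency per rotated pair of dimensions).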
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16.
            # Also, the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
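                # XPos: the cos / sin cached for Q are multiplied by this per-position scale and
                # those cached for K are divided by it, so Q @ K^T only sees the relative position.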
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
             else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
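

# End-to-end usage sketch (illustrative, not part of the original module); the batch size,
# sequence length, head count, head dim, and dtype below are arbitrary assumptions:
#     rotary = RotaryEmbedding(dim=64, interleaved=False, device="cuda")
#     qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.bfloat16)
#     qkv = rotary(qkv)  # prefill / training: positions 0 .. 127
#     qkv_new = torch.randn(2, 1, 3, 8, 64, device="cuda", dtype=torch.bfloat16)
#     qkv_new = rotary(qkv_new, seqlen_offset=128)  # decoding step with a KV cache of length 128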