# Copyright (c) 2023, Tri Dao.

from typing import Tuple
import math

import torch

from einops import rearrange, repeat

import rotary_emb


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # For interleaved (GPT-J style) rotation each frequency applies to an adjacent
    # (even, odd) pair of dims, so cos/sin are tiled as (d 2); for the half-and-half
    # (GPT-NeoX style) layout they are tiled as (2 d).
    cos = repeat(cos, 's d -> s 1 (2 d)' if not interleaved else 's d -> s 1 (d 2)')
    sin = repeat(sin, 's d -> s 1 (2 d)' if not interleaved else 's d -> s 1 (d 2)')
    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
                      x[..., ro_dim:]], dim=-1)
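
# A minimal usage sketch of the pure-PyTorch reference path above (shapes are illustrative,
# not part of the API): build cos/sin of shape (seqlen, rotary_dim / 2) from the usual RoPE
# frequencies and pass them in, e.g.
#   inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 32, 2).float() / 32))  # rotary_dim = 32
#   freqs = torch.outer(torch.arange(128.0), inv_freq)                   # (seqlen, 16)
#   out = apply_rotary_emb_torch(torch.randn(2, 128, 4, 64), torch.cos(freqs), torch.sin(freqs))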


class ApplyRotaryEmb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim.
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
                      else (out_ro[..., ::2], out_ro[..., 1::2]))
        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (do_ro[..., ::2], do_ro[..., 1::2]))
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
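
# Usage sketch (shapes are illustrative): the autograd Function above is exposed through
# apply_rotary_emb_func. torch.autograd.Function.apply does not take keyword arguments,
# so interleaved / inplace must be passed positionally, e.g.
#   out = apply_rotary_emb_func(x, cos, sin, False, False)
# with x: (batch, seqlen, nheads, headdim) and cos, sin: (seqlen, rotary_dim / 2).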


class ApplyRotaryEmbQKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim.
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
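
# Usage sketch: apply_rotary_emb_qkv_ rotates q and k of a packed qkv tensor in place and
# also returns it, e.g.
#   qkv = apply_rotary_emb_qkv_(qkv, cos, sin)   # qkv: (batch, seqlen, 3, nheads, headdim)
# cos_k / sin_k only need to be passed when keys use a different scaling than queries
# (the XPos path in RotaryEmbedding below); otherwise they default to cos / sin.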


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

167
    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
168
169
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
Tri Dao's avatar
Tri Dao committed
170
171
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
                 pos_idx_in_fp32=True, device=None):
        """
175
176
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
177
178
179
180
181
182
183
184
185
186
            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
                otherwise they might be in lower precision.
                This option was added because previously (before 2023-07-02), when we construct
                the position indices, we use the dtype of self.inv_freq. In most cases this would
                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
                self.inv_freq would be bf16, and the position indices are also in bf16.
                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
                embeddings for some positions will coincide.
                To maintain compatibility with models previously trained in pure bf16,
                we add this option.
Tri Dao's avatar
Tri Dao committed
187
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                                 dtype=torch.float32) / self.dim))


    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if self.scale is None:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                None, None, self.interleaved
            )
        else:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                self.interleaved
            )
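

if __name__ == "__main__":
    # Minimal smoke test; a sketch with illustrative shapes, not part of the library API.
    # The fused path (rotary_emb extension + RotaryEmbedding) needs a CUDA device, so only
    # the pure-PyTorch reference path is exercised unconditionally.
    batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    out = apply_rotary_emb_torch(x, torch.cos(freqs), torch.sin(freqs))
    assert out.shape == x.shape
    # Dimensions beyond rotary_dim pass through unchanged.
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])

    if torch.cuda.is_available():
        rotary = RotaryEmbedding(rotary_dim, device="cuda")
        qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
        out_qkv = rotary(qkv)
        assert out_qkv.shape == (batch, seqlen, 3, nheads, headdim)

    print("rotary.py smoke test passed")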