# Copyright (c) 2023, Tri Dao.

from typing import Tuple
import math

import torch

from einops import rearrange, repeat

import rotary_emb


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
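
# An illustrative example of the two layouts for a head dimension of 4, x = [x0, x1, x2, x3]:
#   non-interleaved (GPT-NeoX style), pairing dimension i with i + headdim/2:
#       rotate_half(x)                   == [-x2, -x3, x0, x1]
#   interleaved (GPT-J style), pairing adjacent even/odd dimensions:
#       rotate_half(x, interleaved=True) == [-x1, x0, -x3, x2]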


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, 's d -> s 1 (2 d)')
    sin = repeat(sin, 's d -> s 1 (2 d)')
    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
                      x[..., ro_dim:]], dim=-1)
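
# Example usage of this pure-PyTorch reference path (shapes chosen for illustration):
#     x = torch.randn(2, 128, 12, 64)                 # (batch_size, seqlen, nheads, headdim)
#     inv_freq = 1.0 / (10000 ** (torch.arange(0, 32, 2).float() / 32))
#     freqs = torch.outer(torch.arange(128).float(), inv_freq)   # (seqlen, rotary_dim / 2)
#     cos, sin = torch.cos(freqs), torch.sin(freqs)
#     out = apply_rotary_emb_torch(x, cos, sin)
#     # only the first rotary_dim = 32 features of each head are rotated; out.shape == x.shape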


class ApplyRotaryEmb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
            x: (batch_size, seqlen, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
                      else (out_ro[..., ::2], out_ro[..., 1::2]))
        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (do_ro[..., ::2], do_ro[..., 1::2]))
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
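
# Example usage of the fused kernel path (illustrative; requires CUDA and the rotary_emb
# extension, with cos / sin built as in the apply_rotary_emb_torch example above):
#     x = torch.randn(2, 128, 12, 64, dtype=torch.float16, device='cuda', requires_grad=True)
#     out = apply_rotary_emb_func(x, cos.half().cuda(), sin.half().cuda())
#     out.sum().backward()    # gradients flow through ApplyRotaryEmb.backward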


class ApplyRotaryEmbQKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
            qkv: (batch_size, seqlen, 3, nheads, headdim)
            cos, sin: (seqlen, rotary_dim / 2)
            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
                1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
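
# Example usage on a packed qkv tensor, typically the reshaped output of a Wqkv projection
# (illustrative; requires CUDA, the rotary_emb extension, and cos / sin as above):
#     qkv = torch.randn(2, 128, 3, 12, 64, dtype=torch.float16, device='cuda')
#     qkv = apply_rotary_emb_qkv_(qkv, cos.half().cuda(), sin.half().cuda())
#     # q (qkv[:, :, 0]) and k (qkv[:, :, 1]) are rotated in place; v (qkv[:, :, 2]) is untouched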


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; the GPT-NeoX implementation was an inspiration for this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
        """
            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                of 1st half and 2nd half (GPT-NeoX style).
        """
        super().__init__()
        self.base = float(base)
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, seqlen_offset=0):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
        """
        seqlen = x.shape[1] + seqlen_offset
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
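                # With this scaling, a query at position m is multiplied by
                # scale ** ((m - seqlen // 2) / scale_base) and a key at position n by the
                # reciprocal, so their dot product carries a relative factor
                # scale ** ((m - n) / scale_base), which is the XPos decay with distance.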

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset)
        if self.scale is None:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                None, None, self.interleaved
            )
        else:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                self.interleaved
            )
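

# Example usage of the module (illustrative; dimensions are arbitrary and the rotary_emb
# extension must be available):
#     rotary = RotaryEmbedding(dim=64, interleaved=False).cuda()
#     qkv = torch.randn(2, 128, 3, 12, 64, dtype=torch.float16, device='cuda')
#     qkv = rotary(qkv)                  # rotates q and k in place, caching cos/sin up to seqlen
#     # During generation, pass seqlen_offset so the cached cos/sin are indexed at the new
#     # positions, e.g. rotary(next_qkv, seqlen_offset=128) for a hypothetical next_qkv of
#     # shape (2, 1, 3, 12, 64).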