# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
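
# Worked example (illustrative): for a trailing-dim vector [x0, x1, x2, x3],
#   rotate_half(x, interleaved=False) -> [-x2, -x3, x0, x1]  (GPT-NeoX style: rotate the two halves)
#   rotate_half(x, interleaved=True)  -> [-x1, x0, -x3, x2]  (GPT-J style: rotate even/odd pairs)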


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
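
# Minimal usage sketch for the pure-PyTorch reference path above (the shapes and rotary_dim
# are assumptions for illustration; any rotary_dim <= headdim with cos/sin of shape
# (seqlen, rotary_dim / 2) works):
#
#   x = torch.randn(2, 16, 4, 64)  # (batch_size, seqlen, nheads, headdim)
#   inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 32, 2).float() / 32))
#   freqs = torch.outer(torch.arange(16, dtype=torch.float32), inv_freq)
#   cos, sin = torch.cos(freqs), torch.sin(freqs)  # (16, 16) == (seqlen, rotary_dim / 2)
#   out = apply_rotary_emb_torch(x, cos, sin)  # rotates the first 32 head dims, passes the rest through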


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
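
# Usage sketch (assumes a CUDA device, since apply_rotary is a Triton kernel; shapes, dtype,
# and the offset value are illustrative):
#
#   x = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, device="cuda").float() / 64))
#   freqs = torch.outer(torch.arange(128, device="cuda", dtype=torch.float32), inv_freq)
#   cos, sin = torch.cos(freqs).to(x.dtype), torch.sin(freqs).to(x.dtype)
#   out = apply_rotary_emb(x, cos, sin)  # differentiable; gradients flow back to x
#   # Single decode step at position 127, reusing the same cos/sin cache:
#   x_step = torch.randn(2, 1, 16, 64, device="cuda", dtype=torch.float16)
#   out_step = apply_rotary_emb(x_step, cos, sin, seqlen_offsets=127)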


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
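
# Usage sketch (assumes a CUDA device for the Triton kernel; shapes are illustrative and
# cos/sin are assumed to be built as in the apply_rotary_emb example above):
#
#   qkv = torch.randn(2, 128, 3, 16, 64, device="cuda", dtype=torch.float16)
#   qkv = apply_rotary_emb_qkv_(qkv, cos, sin)  # Q and K rotated in place, V left untouched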


class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
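
# Usage sketch (assumes a CUDA device; shapes are illustrative and cos/sin as above):
#
#   kv = torch.randn(2, 128, 2, 16, 64, device="cuda", dtype=torch.float16)
#   kv = apply_rotary_emb_kv_(kv, cos, sin)  # only K (kv[:, :, 0]) is rotated, in place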


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
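
    # Worked example: with dim=8 and base=10000.0, _compute_inv_freq returns
    #   1.0 / 10000.0 ** (tensor([0., 2., 4., 6.]) / 8) = tensor([1.0, 0.1, 0.01, 0.001]),
    # i.e. one geometrically decreasing frequency per rotated pair of head dimensions.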

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
             else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
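
# Usage sketch for RotaryEmbedding (assumes a CUDA device, since the underlying rotary kernel
# is written in Triton; shapes, dtype, and the seqlen_offset value are illustrative):
#
#   rotary = RotaryEmbedding(dim=64, interleaved=False, device="cuda")
#   qkv = torch.randn(2, 128, 3, 16, 64, device="cuda", dtype=torch.float16)
#   qkv = rotary(qkv)  # rotates Q and K of the packed qkv in place, V untouched
#   # Decoding with a KV cache: a single new token at position 128
#   qkv_step = torch.randn(2, 1, 3, 16, 64, device="cuda", dtype=torch.float16)
#   qkv_step = rotary(qkv_step, seqlen_offset=128)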