# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
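# Worked example (illustrative): with headdim 4, rotate_half maps
#   non-interleaved (GPT-NeoX style): [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
#   interleaved (GPT-J style):        [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]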


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
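# Usage sketch for the pure-PyTorch reference above (hypothetical shapes):
#   x = torch.randn(2, 128, 8, 64)                           # (batch, seqlen, nheads, headdim)
#   inv_freq = 1.0 / 10000 ** (torch.arange(0, 32, 2) / 32)  # rotary_dim = 32
#   freqs = torch.outer(torch.arange(128), inv_freq)         # (seqlen, rotary_dim / 2)
#   out = apply_rotary_emb_torch(x, freqs.cos(), freqs.sin())  # last 32 dims pass through unchanged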


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
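# Usage sketch (hypothetical shapes; the underlying Triton kernel expects CUDA tensors):
#   x = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#   inv_freq = 1.0 / 10000 ** (torch.arange(0, 32, 2, device="cuda") / 32)
#   freqs = torch.outer(torch.arange(128, device="cuda"), inv_freq)
#   cos, sin = freqs.cos().to(x.dtype), freqs.sin().to(x.dtype)
#   out = apply_rotary_emb(x, cos, sin)
# For decoding with a KV cache, seqlen_offsets can be an int or a (batch_size,) tensor giving
# each sequence's offset into the cos / sin tables.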


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
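# Usage sketch (hypothetical shapes; Q and K are rotated in-place, V is left untouched):
#   qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
#   qkv = apply_rotary_emb_qkv_(qkv, cos, sin)  # cos / sin built as in apply_rotary_emb above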


class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
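# Usage sketch (hypothetical shapes; only K, i.e. kv[:, :, 0], is rotated in-place):
#   kv = torch.randn(2, 128, 2, 8, 64, device="cuda", dtype=torch.float16)
#   kv = apply_rotary_emb_kv_(kv, cos, sin)  # cos / sin built as in apply_rotary_emb above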


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
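        # For illustration: with dim=8 and base=10000.0, the exponents arange(0, 8, 2) / 8 are
        # [0, 0.25, 0.5, 0.75], so inv_freq = [1.0, 0.1, 0.01, 0.001].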

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
             else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
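# Usage sketch for RotaryEmbedding (hypothetical shapes and values):
#   rotary = RotaryEmbedding(dim=64, interleaved=False, device="cuda")
#   qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
#   qkv = rotary(qkv)  # Q and K rotated in-place; V is untouched
#   # During incremental decoding with a KV cache, pass the current position offset:
#   qkv_step = torch.randn(2, 1, 3, 8, 64, device="cuda", dtype=torch.float16)
#   qkv_step = rotary(qkv_step, seqlen_offset=128)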