# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Rotary Position Embedding implementations of different types, along with helper functions.
"""
from typing import Optional, Tuple, Union, List

import torch

import transformer_engine_torch as tex

from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat


__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int, default = None
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            Pre-trained max_position_embeddings before position interpolation.
        rotary_base: float, default = 10000.0
            Base of the rotary position embedding.
        interleaved: bool, default = False
            Whether to use interleaved rotary position embedding.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.interleaved = interleaved

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        This function is particularly sensitive to the use of mixed precision, so we disable the
        autocast context if it is enabled.

        Parameters
        ----------
        max_seq_len: int
            Sequence length of a sample.
        offset: int, default = 0
            Fixed offset for frequencies.
        """
        with torch.autocast(enabled=False, device_type="cuda"):
            seq = (
                torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
                + offset
            )

            if (
                self.pretrained_max_position_embeddings is not None
                and self.seq_len_interpolation_factor is not None
            ):
                if (
                    max_seq_len
                    > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
                ):
                    # dynamic linear scaling (length > position we have learned)
                    seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
                else:
                    # fixed linear scaling
                    seq *= 1 / self.seq_len_interpolation_factor

            freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)

        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        if not self.interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )

        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
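
# Usage sketch (illustrative only, not part of the module API; the head dimension
# and sequence length below are arbitrary examples):
#
#     rope = RotaryPositionEmbedding(dim=64)
#     freqs = rope(max_seq_len=2048)   # float32 tensor of shape [2048, 1, 1, 64]
#     # `freqs` can then be passed to `apply_rotary_pos_emb` defined below.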


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Fused RoPE forward."""

        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
            "thd",
        ), f"Unsupported tensor_format: {tensor_format}."
        output = tex.fused_rope_forward(
            t,
            freqs,
            start_positions,
            QKVFormat[tensor_format],
            interleaved,
            cu_seqlens,
            cp_size,
            cp_rank,
        )
        ctx.save_for_backward(freqs, cu_seqlens, start_positions)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward."""
        freqs, cu_seqlens, start_positions = ctx.saved_tensors
        grad_input = tex.fused_rope_backward(
            grad_output,
            freqs,
            start_positions,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            cu_seqlens,
            ctx.cp_size,
            ctx.cp_rank,
        )

        return grad_input, None, None, None, None, None, None, None, None


class FusedQKVRoPEFunc(torch.autograd.Function):
    """
    Function for FusedQKVRoPE

    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format, along with the Q and K RoPE tensors as additional required inputs.
    The RoPE tensors should be of shape (s, 1, 1, d). It produces 3 outputs: Q and K with RoPE applied, and V unchanged from the input.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Fused RoPE forward."""

        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."
        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
        output = tex.fused_qkv_rope_forward(
            qkv,
            q_freqs,
            k_freqs,
            start_positions,
            qkv_split_arg_list,
            QKVFormat[tensor_format],
            interleaved,
            cp_size,
            cp_rank,
        )
        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved
        return output

    @staticmethod
    def backward(
        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward."""
        q_freqs, k_freqs = ctx.saved_tensors

        grad_output_q = grad_output_q.contiguous()
        grad_output_k = grad_output_k.contiguous()
        grad_output_v = grad_output_v.contiguous()

        grad_input = tex.fused_qkv_rope_backward(
            grad_output_q,
            grad_output_k,
            grad_output_v,
            q_freqs,
            k_freqs,
            ctx.qkv_split_arg_list,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            ctx.cp_size,
            ctx.cp_rank,
        )

        return grad_input, None, None, None, None, None, None, None, None


def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]

    Args:
        x: torch.Tensor. Input tensor.
        interleaved: bool. Whether to use interleaved rotary position embedding.

    Returns:
        Tensor: Tensor rotated half.
    """
    if not interleaved:
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # interleaved
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x_new = torch.stack((-x2, x1), dim=-1)
    return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
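
# Worked example (illustrative only): for a last dimension [x1, x2, x3, x4],
# the non-interleaved path chunks it in half and negates the second half,
# giving [-x3, -x4, x1, x2]; the interleaved path rotates each even/odd pair,
# giving [-x2, x1, -x4, x3].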


def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t : torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs : torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
        and dtype 'float', with `s2 >= s` and `d2 <= d`.
    tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
    # [seq, b, 1, dim] -> [b, seq, 1, dim]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)

    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
    return torch.cat((t, t_pass), dim=-1)
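
# Equivalently, each rotary pair is rotated by its position-dependent angle:
# for a pair (x_a, x_b) sharing frequency theta at position m,
#     out_a = x_a * cos(m * theta) - x_b * sin(m * theta)
#     out_b = x_b * cos(m * theta) + x_a * sin(m * theta)
# In the non-interleaved layout the pair is (t[..., i], t[..., i + d/2]);
# in the interleaved layout it is (t[..., 2i], t[..., 2i + 1]).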


def _get_freqs_on_this_cp_rank(
    freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
    """Get the position embedding on the current context parallel rank.

    Args:
        freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
        seqlen: int. Length of the current sequence.
        cp_size: int. Context parallel world size.
        cp_rank: int. Context parallel rank.
    """
    if cp_size > 1:
        cp_seg = seqlen // 2
        full_seqlen = cp_size * seqlen
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )

    # cp_size == 1
    return freqs[:seqlen]
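
# Illustrative example: with cp_size=2 and a per-rank seqlen of 4 (full sequence
# length 8), rank 0 keeps freqs for positions [0, 1, 6, 7] and rank 1 keeps
# positions [2, 3, 4, 5], i.e. the "head + tail" chunk assignment used for
# load-balanced causal attention under context parallelism.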


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Support matrix:
    Fused/Unfused:
        Training:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallel:       yes
            start_positions:        yes
            interleaving:           yes
        Inference:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallelism:    no
            start_positions:        yes
            interleaving:           yes

    Parameters
    ----------
    t : torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs : torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    start_positions : torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
    tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.
    fused : bool, default = False
        Whether to use the fused implementation for applying RoPE.
    cu_seqlens : torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size : int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank : int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
    """
    assert (
        tensor_format != "thd" or cu_seqlens is not None
    ), "cu_seqlens must not be None when tensor_format is 'thd'."

    # Fused apply rope logic for THD/BSHD/SBHD formats
    if fused:
        return FusedRoPEFunc.apply(
            t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
        )

    # Unfused apply rope logic for THD format
    if tensor_format == "thd":
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        # The following code essentially splits the `thd` tensor into corresponding
        # `s1hd` tensors (for each sequence) and applies rotary embedding to
        # those sequences individually.
        # Note that if `start_positions` is not `None`, then for each sequence,
        # the freqs supplied are offset by the corresponding `start_positions` value.
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(
                        (
                            freqs[start_positions[idx] :] if start_positions is not None else freqs
                        ),  # offset the freqs
                        x.size(0),
                        cp_size,
                        cp_rank,
                    ),
                    interleaved=interleaved,
                )
                for idx, x in enumerate(torch.split(t, seqlens))
            ]
        ).squeeze(1)

    # Unfused apply rope logic for SBHD/BSHD format follows ...

    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")

    if start_positions is not None:
        max_offset = torch.max(start_positions)
        assert (
            max_offset + seqlen * cp_size <= freqs.shape[0]
        ), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!"

        # Stack staggered rope embeddings along the batch dimension
        freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1)
        # Note that from this point, `freqs` has a shape `(s,b,1,d)`.

    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        tensor_format,
        interleaved=interleaved,
    )
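
# Usage sketch (illustrative only; shapes, device and dtype are arbitrary examples):
#
#     rope = RotaryPositionEmbedding(dim=64)
#     freqs = rope(max_seq_len=2048)                                         # [2048, 1, 1, 64]
#     q = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)   # [s, b, h, d]
#     q = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd", fused=True)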


def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input qkv tensor.

    Support matrix:
    Fused:
        Training:
            qkv_formats:            "bshd", "sbhd"
            context parallel:       yes
            start_positions:        no
            interleaving:           yes
        Inference:
            qkv_formats:            "bshd", "sbhd"
            context parallelism:    no
            start_positions:        yes
            interleaving:           yes

    Parameters
    ----------
    qkv : torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
        rotary positional embedding will be applied. This tensor has q, k, v concatenated
        along the last dimension.
    q_freqs : torch.Tensor
        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    k_freqs : torch.Tensor
        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list : List[int]
        List of integers that specify the split of the qkv tensor. The list should have 3 elements,
        the first element is the number of elements in the q tensor, the second element is the number
        of elements in the k tensor, and the third element is the number of elements in the v tensor.
        The sum of the elements in the list should be equal to the last dimension of the qkv tensor.
    start_positions : torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
    tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
        is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
        of shape `[seq, bs, ...]`.
    interleaved : bool, default = False
        Whether to use interleaved rotary position embedding.
    cp_size : int, default = 1.
        Context parallel world size.
    cp_rank : int, default = 0.
        Context parallel rank.
    """

    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""

    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    return FusedQKVRoPEFunc.apply(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
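
# Usage sketch (illustrative only; assumes `freqs` was created as in the examples
# above, with per-head q/k/v dimensions of 64 packed along the last axis):
#
#     qkv = torch.randn(128, 2, 16, 192, device="cuda", dtype=torch.bfloat16)
#     q, k, v = apply_fused_qkv_rotary_pos_emb(
#         qkv, freqs, freqs, qkv_split_arg_list=[64, 64, 64], tensor_format="sbhd"
#     )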