# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Rotary Position Embedding implementations of different types, along with helper functions.
"""
from typing import Optional, Tuple, Union, List
import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat


__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int, default = None
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            Pre-trained max_position_embeddings before position interpolation.
        rotary_base: float, default = 10000.0
            Base of the rotary position embedding.
        interleaved: bool, default = False
            Whether to use interleaved rotary position embedding.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
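        # Inverse frequencies: inv_freq[j] = rotary_base ** (-2j / dim) for j in [0, dim // 2).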
        inv_freq = 1.0 / (
            self.rotary_base
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.interleaved = interleaved

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            Sequence length of a sample.
        offset: int, default = 0
            Fixed offset for frequencies.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (sequence length exceeds the learned positions)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        if not self.interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
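

# A minimal usage sketch (illustrative only; the shapes, dtypes, and `sbhd` layout below are
# assumptions rather than part of this module's API):
#
#     rope = RotaryPositionEmbedding(dim=64)
#     freqs = rope(max_seq_len=128)                       # [128, 1, 1, 64], float32
#     q = torch.randn(128, 2, 8, 64, dtype=torch.bfloat16, device="cuda")
#     q = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd")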


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, so it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Fused RoPE forward."""

        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
            "thd",
        ), f"Unsupported tensor_format: {tensor_format}."
        output = tex.fused_rope_forward(
            t,
            freqs,
            start_positions,
            QKVFormat[tensor_format],
            interleaved,
            cu_seqlens,
            cp_size,
            cp_rank,
        )
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward."""
        freqs, cu_seqlens = ctx.saved_tensors
        grad_input = tex.fused_rope_backward(
            grad_output,
            freqs,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            cu_seqlens,
            ctx.cp_size,
            ctx.cp_rank,
        )

        return grad_input, None, None, None, None, None, None, None


class FusedQKVRoPEFunc(torch.autograd.Function):
    """
    Function for FusedQKVRoPE

    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format, along with
    separate Q and K RoPE tensors of shape (s, 1, 1, d) as additional required inputs. It
    produces 3 outputs: Q and K with RoPE applied, and V unchanged from the input.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Fused RoPE forward."""

        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."
        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
        output = tex.fused_qkv_rope_forward(
            qkv,
            q_freqs,
            k_freqs,
            start_positions,
            qkv_split_arg_list,
            QKVFormat[tensor_format],
            interleaved,
            cp_size,
            cp_rank,
        )
        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved
        return output

    @staticmethod
    def backward(
        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward."""
        q_freqs, k_freqs = ctx.saved_tensors

        grad_output_q = grad_output_q.contiguous()
        grad_output_k = grad_output_k.contiguous()
        grad_output_v = grad_output_v.contiguous()

        grad_input = tex.fused_qkv_rope_backward(
            grad_output_q,
            grad_output_k,
            grad_output_v,
            q_freqs,
            k_freqs,
            ctx.qkv_split_arg_list,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            ctx.cp_size,
            ctx.cp_rank,
        )

        return grad_input, None, None, None, None, None, None, None, None


def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]

    Args:
        x: torch.Tensor. Input tensor.
        interleaved: bool. Whether to use interleaved rotary position embedding.

    Returns:
        Tensor: Tensor rotated half.
    """
    if not interleaved:
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # interleaved
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x_new = torch.stack((-x2, x1), dim=-1)
    return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
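

# A small worked illustration of `_rotate_half` (arbitrary values, for exposition only):
#
#     x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
#     _rotate_half(x, interleaved=False)  # -> [[[[-3.0, -4.0, 1.0, 2.0]]]]
#     _rotate_half(x, interleaved=True)   # -> [[[[-2.0, 1.0, -4.0, 3.0]]]]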


def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    start_positions: Union[torch.Tensor, None] = None,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    start_positions: torch.Tensor, default = None.
        Tokens in sequence `i` are position-encoded with an offset of `start_positions[i]`.
        If `start_positions=None`, there is no offset.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # In case `start_positions` are provided, create a staggered `freqs` tensor
    # offset by the values in `start_positions`.
    # `start_positions` is only supported for `cp_size=1` and inference.
    if start_positions is not None:
        max_offset = torch.max(start_positions)
        assert (
            max_offset + cur_seq_len <= max_seq_len
        ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!"

        # Stack staggered rope embeddings along the batch dimension
        freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)

        # Note that from this point, `freqs` has a shape `(s,b,1,d)`.

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]

    # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
    # [seq, b, 1, dim] -> [b, seq, 1, dim]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
    return torch.cat((t, t_pass), dim=-1)
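

# In effect, for the non-interleaved layout each pair (t[..., i], t[..., i + d/2]) inside the
# rotary dimension d is rotated by the angle freqs[..., i]:
#
#     c, s = cos(freqs[..., i]), sin(freqs[..., i])
#     out[..., i]       = t[..., i] * c - t[..., i + d/2] * s
#     out[..., i + d/2] = t[..., i + d/2] * c + t[..., i] * s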


def _get_freqs_on_this_cp_rank(
    freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
    """Get the position embedding on the current context parallel rank.

    Args:
        freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
        seqlen: int. Length of the current sequence.
        cp_size: int. Context parallel world size.
        cp_rank: int. Context parallel rank.
    """
    if cp_size > 1:
        cp_seg = seqlen // 2
        full_seqlen = cp_size * seqlen
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )

    # cp_size == 1
    return freqs
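

# A hypothetical illustration of the context-parallel split above (the numbers are assumptions):
# with cp_size=2 and seqlen=4 positions per rank out of a full sequence of 8, each rank takes
# two chunks (one from each end), matching the usual load-balanced layout for causal attention:
#
#     rank 0 -> freqs[0:2] and freqs[6:8]
#     rank 1 -> freqs[2:4] and freqs[4:6]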


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Support matrix:
    Fused/Unfused:
        Training:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallel:       yes
            start_positions:        no
            interleaving:           yes
        Inference:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallelism:    no
            start_positions:        yes
            interleaving:           yes

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    start_positions: torch.Tensor, default = None.
        Tokens in sequence `i` are position-encoded with an offset of `start_positions[i]`.
        If `start_positions=None`, there is no offset.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`. `thd` is only supported when `fused` is True.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    fused: bool, default = False
        Whether to use the fused implementation for applying RoPE.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
        Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
    """

    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""

    assert (
        tensor_format != "thd" or cu_seqlens is not None
    ), "cu_seqlens must not be None when tensor_format is 'thd'."

    if fused:
        return FusedRoPEFunc.apply(
            t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
        )

    # Unfused THD format
    if tensor_format == "thd":
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        # The following code essentially splits the `thd` tensor into the corresponding
        # `s1hd` tensors (one per sequence) and applies the rotary embedding to each
        # sequence individually.
        # Note that if `start_positions` is not `None`, the corresponding RoPE offset for
        # each sequence is also supplied individually from `start_positions`.
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
                    start_positions=(
                        start_positions[idx : idx + 1] if start_positions is not None else None
                    ),
                    interleaved=interleaved,
                )
                for idx, x in enumerate(torch.split(t, seqlens))
            ]
        ).squeeze(1)

    # Unfused SBHD/BSHD format
    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        start_positions,
        tensor_format,
        interleaved=interleaved,
    )
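

# A minimal sketch of the packed `thd` path (illustrative only; the sequence lengths, dtypes,
# and the `fused=True` choice below are assumptions):
#
#     # Two packed sequences of lengths 3 and 5 (8 tokens total), 16 heads of dim 64.
#     freqs = RotaryPositionEmbedding(dim=64)(max_seq_len=8)
#     t = torch.randn(8, 16, 64, dtype=torch.bfloat16, device="cuda")
#     cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
#     out = apply_rotary_pos_emb(
#         t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
#     )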


def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input qkv tensor.

    Support matrix:
    Fused:
        Training:
            qkv_formats:            "bshd", "sbhd"
            context parallel:       yes
            start_positions:        no
            interleaving:           yes
        Inference:
            qkv_formats:            "bshd", "sbhd"
            context parallelism:    no
            start_positions:        yes
            interleaving:           yes

    Parameters
    ----------
    qkv: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
        rotary positional embedding will be applied. This tensor has q, k, v concatenated
        along the last dimension.
    q_freqs: torch.Tensor
        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    k_freqs: torch.Tensor
        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list: List[int]
        List of integers specifying how to split the qkv tensor along its last dimension.
        The list should have 3 elements: the size of the q portion, the size of the k portion,
        and the size of the v portion. The elements should sum to the last dimension of the
        qkv tensor.
    start_positions: torch.Tensor, default = None.
        Tokens in sequence `i` are position-encoded with an offset of `start_positions[i]`.
        If `start_positions=None`, there is no offset.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    cp_size: int, default = 1.
        Context parallel world size.
    cp_rank: int, default = 0.
        Context parallel rank.
    """

    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""

    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    return FusedQKVRoPEFunc.apply(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
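

# A minimal usage sketch for the fused QKV path (illustrative only; the packed MHA layout and
# split sizes below are assumptions, not the only supported configuration):
#
#     # Hypothetical layout: 8 heads of dim 64, with q, k, v packed along the last dimension.
#     qkv = torch.randn(1024, 2, 8, 3 * 64, dtype=torch.bfloat16, device="cuda")
#     freqs = RotaryPositionEmbedding(dim=64)(max_seq_len=1024)
#     q_out, k_out, v_out = apply_fused_qkv_rotary_pos_emb(
#         qkv, freqs, freqs, qkv_split_arg_list=[64, 64, 64], tensor_format="sbhd"
#     )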