rope.py 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union
import torch
10

11
import transformer_engine_torch as tex
12
13
14
15
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat


# Names exported by `from ... import *`; everything else here is internal.
__all__ = [
    "RotaryPositionEmbedding",
    "apply_rotary_pos_emb",
]


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.

    `forward` returns the raw rotation angles (position * inverse frequency); the
    cos/sin of these angles are taken later, when the embedding is applied.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
            Values below 1.0 shrink the effective dimension to `int(dim * rotary_percent)`.
        seq_len_interpolation_factor: int, default = None
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595. Only takes effect when
            `pretrained_max_position_embeddings` is also set.
        pretrained_max_position_embeddings: int, default = None
            Pre-trained max_position_embeddings before position interpolation.
        rotary_base: float, default = 10000.0
            Base of the rotary position embedding.
        interleaved: bool, default = False
            Whether to use interleaved rotary position embedding.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.rotary_base = rotary_base
        # Build the buffer on the current CUDA device when one is available (the
        # original behaviour); fall back to CPU so the module can also be created on
        # hosts without CUDA instead of failing inside `torch.cuda.current_device()`.
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        inv_freq = 1.0 / (
            self.rotary_base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.interleaved = interleaved

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            Sequence length of a sample.
        offset: int, default = 0
            Fixed offset for frequencies.

        Returns
        -------
        torch.Tensor
            Angles of shape `[max_seq_len, 1, 1, 2 * len(inv_freq)]`, on the device and
            with the dtype of the `inv_freq` buffer.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        # Position interpolation is applied only when BOTH knobs are configured.
        if (
            self.pretrained_max_position_embeddings is not None
            and self.seq_len_interpolation_factor is not None
        ):
            if (
                max_seq_len
                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
            ):
                # dynamic linear scaling (length > position we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor

        # freqs[i, j] = seq[i] * inv_freq[j]
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        if not self.interleaved:
            # first part even vector components, second part odd vector components,
            #  2 * dim in dimension size
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            # each frequency repeated twice in a row: [f0, f0, f1, f1, ...]
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))


class FusedRoPEFunc(torch.autograd.Function):
    """
    Autograd wrapper around the fused RoPE kernels of `transformer_engine_torch`.

    The input tensor is expected in `sbhd`, `bshd` or `thd` layout and the RoPE
    tensor in shape (s, 1, 1, d). Arbitrary memory layouts are accepted so that
    costly `.contiguous()` copies are avoided, at the price of a possibly
    sub-optimal memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        """Rotate `t` with the fused kernel and remember what backward needs."""
        # The fused kernel is always handed float32 frequencies.
        freqs_fp32 = freqs.float() if freqs.dtype != torch.float32 else freqs
        assert tensor_format in (
            "sbhd",
            "bshd",
            "thd",
        ), f"Unsupported tensor_format: {tensor_format}."

        rotated = tex.fused_rope_forward(
            t,
            freqs_fp32,
            start_positions,
            QKVFormat[tensor_format],
            interleaved,
            cu_seqlens,
            cp_size,
            cp_rank,
        )

        # Stash everything backward needs to replay the kernel on the gradient.
        # NOTE(review): `start_positions` is not saved / replayed in backward; the
        # public API documents it as inference-only.
        ctx.save_for_backward(freqs_fp32, cu_seqlens)
        ctx.tensor_format, ctx.interleaved = tensor_format, interleaved
        ctx.cp_size, ctx.cp_rank = cp_size, cp_rank
        return rotated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Propagate the gradient to `t` through the fused backward kernel."""
        saved_freqs, saved_cu_seqlens = ctx.saved_tensors
        grad_t = tex.fused_rope_backward(
            grad_output,
            saved_freqs,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            saved_cu_seqlens,
            ctx.cp_size,
            ctx.cp_rank,
        )
        # Only `t` is differentiable; the seven other forward inputs (freqs,
        # start_positions, tensor_format, interleaved, cu_seqlens, cp_size,
        # cp_rank) receive None.
        return (grad_t,) + (None,) * 7
171
172


173
174
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]
175

176
177
178
179
180
181
    Args:
        x: torch.Tensor. Input tensor.
        interleaved: bool. Whether to use interleaved rotary position embedding.

    Returns:
        Tensor: Tensor rotated half.
182
    """
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    if not interleaved:
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # interleaved
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x_new = torch.stack((-x2, x1), dim=-1)
    return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)


def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
Sudhakar Singh's avatar
Sudhakar Singh committed
197
    start_positions: torch.Tensor = None,
198
199
200
201
202
203
204
205
206
207
208
209
210
211
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
Sudhakar Singh's avatar
Sudhakar Singh committed
212
213
214
    start_positions: torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
215
216
217
218
219
220
221
222
223
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

Sudhakar Singh's avatar
Sudhakar Singh committed
224
225
226
227
228
229
230
231
232
233
234
235
236
237
    # In case `start_positions` are provided, create a staggered `freqs` tensor
    # offset by the values in `start_positions`.
    # `start_positions` is only supported for `cp_size=1` and inference.
    if start_positions is not None:
        max_offset = torch.max(start_positions)
        assert (
            max_offset + cur_seq_len <= max_seq_len
        ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!"

        # Stack staggered rope embeddings along the batch dimension
        freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)

        # Note that from this point, `freqs` has a shape `(s,b,1,d)`.

238
239
240
241
242
243
    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
Sudhakar Singh's avatar
Sudhakar Singh committed
244
245
246

    # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
    # [seq, b, 1, dim] -> [b, seq, 1, dim]
247
    if tensor_format == "bshd":
Sudhakar Singh's avatar
Sudhakar Singh committed
248
        freqs = freqs.transpose(0, 1)
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
    return torch.cat((t, t_pass), dim=-1)


def _get_freqs_on_this_cp_rank(
    freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
    """Get the position embedding on the current context parallel rank.

    Args:
        freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
        seqlen: int. Length of the current sequence.
        cp_size: int. Context parallel world size.
        cp_rank: int. Context parallel rank.
273
    """
274
275
276
277
278
279
280
281
282
283
284
    if cp_size > 1:
        cp_seg = seqlen // 2
        full_seqlen = cp_size * seqlen
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )

    # cp_size == 1
Sudhakar Singh's avatar
Sudhakar Singh committed
285
    return freqs
286
287
288
289
290
291


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Support matrix:
    Fused/Unfused:
        Training:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallel:       yes
            start_positions:        no
            interleaving:           yes
        Inference:
            qkv_formats:            "thd", "bshd", "sbhd"
            context parallelism:    no
            start_positions:        yes
            interleaving:           yes

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, `sbhd` if `t` is of shape
        `[seq, bs, ...]`, or `thd` if `t` packs variable-length sequences as `[t, h, d]`
        (requires `cu_seqlens`). All three formats are handled by both the fused and
        the unfused implementation.
    start_positions: torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
        Cannot be combined with `cp_size > 1`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Required when `tensor_format` is 'thd'; only meaningful for
        that format. Should be `cu_seqlens_padded` when cp_size > 1.
    cp_size: int, default = 1.
        Context parallel world size. Honoured by both the fused and the unfused paths.
    cp_rank: int, default = 0.
        Context parallel rank. Honoured by both the fused and the unfused paths.

    Returns
    -------
    torch.Tensor
        `t` with rotary positional embedding applied, in the same layout as `t`.

    Raises
    ------
    AssertionError
        If `start_positions` is given with `cp_size > 1`, or `cu_seqlens` is missing
        for 'thd'.
    ValueError
        If the unfused path receives an unknown `tensor_format`.
    """
    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""

    assert (
        tensor_format != "thd" or cu_seqlens is not None
    ), "cu_seqlens must not be None when tensor_format is 'thd'."

    if fused:
        return FusedRoPEFunc.apply(
            t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
        )

    # Unfused THD format
    if tensor_format == "thd":
        # `cu_seqlens` describes the full (padded) sequences; each CP rank holds
        # 1/cp_size of every sequence, so scale the boundaries down accordingly.
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        # The following code essentially splits the `thd` tensor into corresponding
        # `s1hd` tensors (for each sequence) and applies rotary embedding to
        # those sequences individually.
        # Note that if `start_positions` is not `None`, then for each sequence,
        # its corresponding rope offset is also supplied from `start_positions`
        # individually.
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
                    start_positions=(
                        start_positions[idx : idx + 1] if start_positions is not None else None
                    ),
                    interleaved=interleaved,
                )
                for idx, x in enumerate(torch.split(t, seqlens))
            ]
        ).squeeze(1)

    # Unfused SBHD/BSHD format
    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        start_positions,
        tensor_format,
        interleaved=interleaved,
    )