test_fused_rope.py 7.93 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
import math
5
6
import pytest
import torch
7
from typing import Callable, Tuple, Union
8
from transformer_engine.pytorch.dot_product_attention.rope import (
9
10
11
12
13
14
15
16
17
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)


# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2

18

19
20
21
22
23
24
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
    return torch.sum(output * t)


Sudhakar Singh's avatar
Sudhakar Singh committed
25
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
    """Compare fused and unfused RoPE on "sbhd" / "bshd" inputs.

    A random input of ``seq_length - margin`` tokens is built and checked for every
    context-parallel rank in ``range(cp_size)``:

    * Forward: the fused kernel's output must match an unfused reference. The reference
      runs in float32 and is then cast back to ``dtype``.
    * Backward: input gradients are compared only when no position offsets are used.
    * Layout: the fused output must be contiguous.

    ``start_positions`` is a flag. When True, random per-sequence offsets in
    ``[0, margin)`` are passed to both implementations. Offsets are only supported for
    ``cp_size=1`` and inference, so those cases skip the backward comparison.
    """
    if margin == 0 and start_positions:
        # Offsets are drawn from [0, margin), which is empty when margin == 0
        # (torch.randint would raise). A non-zero margin also keeps offset + position
        # within the `seq_length` rows of the rope embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    if start_positions and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )

    # Arbitrary per-sequence RoPE offsets (drawn after `t` so the RNG stream is unchanged).
    offsets = (
        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )
    # Offsets are an inference-only feature, so backward is exercised only without them.
    check_grads = offsets is None

    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        # Round-trip through a contiguous transposed copy. `t` keeps its logical shape
        # but ends up with non-standard (non-contiguous) strides.
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(seq_length * cp_size)
    assert emb.is_contiguous()

    # Arguments shared by the unfused and fused calls; only `cp_rank` varies per iteration.
    rope_kwargs = dict(
        tensor_format=tensor_format,
        start_positions=offsets,
        interleaved=interleaved,
        cp_size=cp_size,
    )

    for cp_rank in range(cp_size):
        # Unfused reference. The fused kernel computes in float32 internally, so the
        # reference also runs in float32 for a tighter comparison.
        output_unfused = apply_rotary_pos_emb(
            t.float(), emb, fused=False, cp_rank=cp_rank, **rope_kwargs
        ).to(dtype)
        if check_grads:
            loss_func(output_unfused).backward()
            grad_unfused = t.grad.detach().clone()
        t.grad = None

        # Fused kernel.
        output_fused = apply_rotary_pos_emb(t, emb, fused=True, cp_rank=cp_rank, **rope_kwargs)
        if check_grads:
            loss_func(output_fused).backward()
            grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        if check_grads:
            torch.testing.assert_close(grad_fused, grad_unfused)
        assert output_fused.is_contiguous()
129
130


Sudhakar Singh's avatar
Sudhakar Singh committed
131
132
@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
    margin: int,
) -> None:
    """Compare fused and unfused RoPE on packed "thd" inputs.

    Eleven variable-length sequences are packed back to back. Their boundaries are given
    by ``cu_seqlens``; when ``cp_size > 1`` each sequence is first padded to a multiple
    of ``2 * cp_size`` tokens. The following is checked for every context-parallel rank:

    * Forward: the fused kernel's output must match an unfused float32 reference that is
      cast back to ``dtype``.
    * Backward: input gradients are compared only when no offsets are used.
    * Layout: the fused output must be contiguous.

    ``start_positions`` is a flag. When True, random per-sequence offsets in
    ``[0, margin)`` are passed to both implementations. Offsets are only supported for
    ``cp_size=1`` and inference.
    """
    if start_positions and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    device = torch.device("cuda:0")
    head_num = 64
    # Cumulative token boundaries of 11 packed sequences (2048 tokens in total).
    cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]

    # Arbitrary per-sequence RoPE offsets (drawn before `t` so the RNG stream is unchanged).
    offsets = (
        torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )
    # Offsets are an inference-only feature, so backward is exercised only without them.
    check_grads = offsets is None

    if cp_size > 1:
        # Pad each sequence length up to a multiple of 2 * cp_size. Presumably this is so
        # every sequence splits evenly into 2 * cp_size chunks across CP ranks; confirm
        # against apply_rotary_pos_emb.
        chunk = 2 * cp_size
        padded = [0]
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            padded.append(padded[-1] + math.ceil((end - start) / chunk) * chunk)
    else:
        padded = list(cu_seqlens)
    # Read the total from the Python list rather than indexing the device tensor. This
    # avoids a host sync and hands a plain int to the shape / embedding-length arguments.
    total_padded_tokens = padded[-1]
    cu_seqlens_padded = torch.tensor(padded, dtype=torch.int32, device=device)

    # Each CP rank holds 1 / cp_size of the packed tokens.
    t = torch.rand(
        (total_padded_tokens // cp_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if transpose:
        # Round-trip through a contiguous transposed copy. `t` keeps its logical shape but
        # ends up with non-standard (non-contiguous) strides.
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(total_padded_tokens)
    assert emb.is_contiguous()

    # Arguments shared by the unfused and fused calls; only `cp_rank` varies per iteration.
    rope_kwargs = dict(
        tensor_format="thd",
        start_positions=offsets,
        interleaved=interleaved,
        cu_seqlens=cu_seqlens_padded,
        cp_size=cp_size,
    )

    for cp_rank in range(cp_size):
        # Unfused reference. The fused kernel computes in float32 internally, so the
        # reference also runs in float32 for a tighter comparison.
        output_unfused = apply_rotary_pos_emb(
            t.float(), emb, fused=False, cp_rank=cp_rank, **rope_kwargs
        ).to(dtype)
        if check_grads:
            loss_func(output_unfused).backward()
            grad_unfused = t.grad.detach().clone()
        t.grad = None

        # Fused kernel.
        output_fused = apply_rotary_pos_emb(t, emb, fused=True, cp_rank=cp_rank, **rope_kwargs)
        if check_grads:
            loss_func(output_fused).backward()
            grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        if check_grads:
            torch.testing.assert_close(grad_fused, grad_unfused)
        assert output_fused.is_contiguous()