# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
    apply_fused_qkv_rotary_pos_emb,
)


# Gradient is a broadcasted scalar
def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    if isinstance(output, List):
        return sum(t.sum() * 2 for t in output)
    else:
        return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    if isinstance(output, List):
        return sum(torch.sum(t * torch.ones_like(t)) for t in output)
    else:
        t = torch.ones_like(output)
        return torch.sum(output * t)


@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
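    """
    Compare the fused RoPE kernel against the unfused reference for sbhd/bshd
    layouts: outputs and input gradients from both paths must match for every
    parametrized dtype, layout, and context-parallel rank.
    """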
    if margin == 0 and start_positions:
        # Skip: a nonzero margin is needed so that the `start_positions`
        # offsets stay within the maximum length of the RoPE embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )

    # Get arbitrary offsets to be used with RoPE for all the sequences
    start_positions = (
        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )

    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(seq_length * cp_size)
    assert emb.is_contiguous()

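    # The embedding table covers the full seq_length * cp_size positions; each
    # context-parallel rank applies RoPE to its own slice, and the fused and
    # unfused paths must agree on every rank.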
    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        output_unfused = apply_rotary_pos_emb(
            t.float(),
            emb,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=False,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

        # fused
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        )
        loss_fused = loss_func(output_fused)
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        torch.testing.assert_close(grad_fused, grad_unfused)
        assert output_fused.is_contiguous()


@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
    margin: int,
) -> None:
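    """
    Same fused-vs-unfused comparison as above, but for packed (THD) inputs
    described by `cu_seqlens`, including the sequence padding required when
    context parallelism is enabled.
    """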

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]

    # Get arbitrary offsets to be used with RoPE for all the sequences
    start_positions = (
        torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )

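    # For context parallelism, pad every sequence length up to a multiple of
    # 2 * cp_size so each sequence can be split evenly across the CP ranks.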
    if cp_size > 1:
        cu_seqlens_padded = [0]
        for i in range(1, len(cu_seqlens)):
            cu_seqlens_padded.append(
                cu_seqlens_padded[i - 1]
                + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2)
            )
    else:
        cu_seqlens_padded = cu_seqlens
    cu_seqlens_padded = torch.tensor(
        cu_seqlens_padded,
        dtype=torch.int32,
        device=device,
    )
    t = torch.rand(
        (cu_seqlens_padded[-1] // cp_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(cu_seqlens_padded[-1])
    assert emb.is_contiguous()

    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        output_unfused = apply_rotary_pos_emb(
            t.float(),
            emb,
            start_positions=start_positions,
            tensor_format="thd",
            interleaved=interleaved,
            fused=False,
            cu_seqlens=cu_seqlens_padded,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

        # fused
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            tensor_format="thd",
            cu_seqlens=cu_seqlens_padded,
            cp_size=cp_size,
            cp_rank=cp_rank,
        )
        loss_fused = loss_func(output_fused)
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        torch.testing.assert_close(grad_fused, grad_unfused)
        assert output_fused.is_contiguous()


@pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
    """
    This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD
    formats are the same.
    """
    device = torch.device("cuda:0")
    seqlen, max_seqlen = 16, 2048
    batch_size, head_num = 4, 256

    # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and
    # that causes unexpected issues.
    seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)

    cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to(
        device=device, dtype=torch.int32
    )

    # Create a tensor in THD format
    thd = torch.rand(
        (cu_seqlens[-1] // cp_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    thd.requires_grad = True

    # Clone the tensor to create a tensor in BSHD format
    bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
    bshd = bshd.to(dtype=dtype, device=device)
    bshd.requires_grad = True

    # Clone the tensor to create a tensor in SBHD format
    sbhd = bshd.transpose(1, 0).clone().detach()
    sbhd = sbhd.to(dtype=dtype, device=device)
    sbhd.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(max_seqlen)
    assert emb.is_contiguous()

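    # When enabled, use each sequence's cumulative start index in the packed
    # layout as its RoPE offset.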
    start_positions = cu_seqlens[:-1] if start_positions else None

    for cp_rank in range(cp_size):
        # unfused bshd
        output_unfused_bshd = apply_rotary_pos_emb(
            bshd.float(),
            emb,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=False,
            tensor_format="bshd",
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused_bshd = loss_func(output_unfused_bshd)
        loss_unfused_bshd.backward()
        grad_unfused_bshd = bshd.grad.detach().clone()
        bshd.grad = None

        # unfused sbhd
        output_unfused_sbhd = apply_rotary_pos_emb(
            sbhd.float(),
            emb,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=False,
            tensor_format="sbhd",
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        loss_unfused_sbhd = loss_func(output_unfused_sbhd)
        loss_unfused_sbhd.backward()
        grad_unfused_sbhd = sbhd.grad.detach().clone()
        sbhd.grad = None

        # unfused thd
        output_unfused_thd = apply_rotary_pos_emb(
            thd.float(),
            emb,
            start_positions=start_positions,
            tensor_format="thd",
            interleaved=interleaved,
            fused=False,
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        loss_unfused_thd = loss_func(output_unfused_thd)
        loss_unfused_thd.backward()
        grad_unfused_thd = thd.grad.detach().clone()
        thd.grad = None

        torch.testing.assert_close(
            output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
        )
        torch.testing.assert_close(
            output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
            output_unfused_thd,
        )
        torch.testing.assert_close(
            grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
        )
        torch.testing.assert_close(
            grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
        )

        assert output_unfused_thd.is_contiguous()
        assert output_unfused_bshd.is_contiguous()
        assert output_unfused_sbhd.is_contiguous()


@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
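    """
    Compare apply_fused_qkv_rotary_pos_emb against applying RoPE separately to
    the query and key slices of a packed QKV tensor: the split Q/K/V outputs
    (and, when start_positions is unused, the input gradient) must match.
    """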
    if margin == 0 and start_positions:
        # Skip: a nonzero margin is needed so that the `start_positions`
        # offsets stay within the maximum length of the RoPE embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    if start_positions and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    if seq_length - margin < 0:
        pytest.skip("Skipping test with seq_length - margin < 0")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64

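    # Packed QKV input: per head, the last dim holds 4 * hidden_size of query
    # plus hidden_size each of key and value (hidden_size * 6 in total).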
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size * 6),
        dtype=dtype,
        device=device,
    )

    # Get arbitrary offsets to be used with RoPE for all the sequences
    start_positions = (
        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )

    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    t.requires_grad = True

    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_q = rotary_pos_emb_q(seq_length * cp_size)
    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_k = rotary_pos_emb_k(seq_length * cp_size)

    for cp_rank in range(cp_size):
        # Reference path: apply RoPE to the query and key slices separately
        # with the per-tensor kernel; the fused QKV kernel below must match it.

        t_clone = t.clone()
        (query, key, value) = torch.split(
            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
        )
        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)

        query_unfused = apply_rotary_pos_emb(
            query,
            emb_q,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        key_unfused = apply_rotary_pos_emb(
            key,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        value_unfused = value
        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])

        if not isinstance(start_positions, torch.Tensor):
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()

        t.grad = None

        # fused
        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
            t,
            emb_q,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            cp_size=cp_size,
            cp_rank=cp_rank,
            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
        )
        loss_fused = loss_func([query_fused, key_fused, value_fused])

        if not isinstance(start_positions, torch.Tensor):
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(query_fused, query_unfused)
        torch.testing.assert_close(key_fused, key_unfused)
        torch.testing.assert_close(value_fused, value_unfused)

        if not isinstance(start_positions, torch.Tensor):
            torch.testing.assert_close(grad_fused, grad_unfused)


def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
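    """The RoPE embedding table should be the same with and without autocast enabled."""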
    rope_layer = RotaryPositionEmbedding(128)

    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        rope_embeddings_autocast = rope_layer(max_seq_len=1024)

    torch.testing.assert_close(
        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
        rope_embeddings_autocast.to(dtype=torch.bfloat16),
        atol=1e-8,
        rtol=1e-8,
    )