# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import math

import pytest
import torch
from typing import Callable, Tuple, Union

from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)


def _get_thd_freqs_on_this_cp_rank(
    cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    # With context parallelism, each rank holds two chunks of every sequence: chunk
    # `cp_rank` and the mirrored chunk `2 * cp_size - 1 - cp_rank` (load-balanced split).
    # Return the frequency slices corresponding to this rank's two chunks.
    if cp_size > 1:
        cp_seg = x.size(0) // 2
        full_seqlen = cp_size * x.size(0)
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )
    else:
        return freqs[: x.size(0)]


def apply_rotary_pos_emb_thd(
    t: torch.Tensor,
    cu_seqlens: torch.Tensor,
    freqs: torch.Tensor,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for the `thd` format.

    Args:
        t (Tensor): Input tensor of shape [t, h, d].
        cu_seqlens (Tensor): Cumulative sequence lengths for `t`, with shape [b + 1]
            and dtype torch.int32.
        freqs (Tensor): Rotary positional embedding tensor of shape [max_s, 1, 1, d].
        cp_size (int): Context-parallel world size.
        cp_rank (int): Context-parallel rank to apply the embedding for.

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    cu_seqlens = cu_seqlens // cp_size
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.cat(
        [
            apply_rotary_pos_emb(
                x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
            )
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)


# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
    return torch.sum(output * t)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    # unfused
    # The fused kernel computes in float32 internally, so we force the unfused func to use
    # float32 for a more accurate comparison.
    output_unfused = apply_rotary_pos_emb(
        t.float(), emb, tensor_format=tensor_format, fused=False
    ).to(dtype)
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
    t.grad = None

    # fused
    output_fused = apply_rotary_pos_emb(
        t,
        emb,
        tensor_format=tensor_format,
        fused=True,
    )
    loss_fused = loss_func(output_fused)
    loss_fused.backward()
    grad_fused = t.grad.detach().clone()
    t.grad = None

    torch.testing.assert_close(output_fused, output_unfused)
    torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()
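

# The thd test below relies on the load-balanced context-parallel layout implemented by
# `_get_thd_freqs_on_this_cp_rank`: each padded sequence is cut into 2 * cp_size equal chunks
# and rank r keeps chunk r together with the mirrored chunk 2 * cp_size - 1 - r. The helper
# below is a minimal illustrative sketch of that mapping; it is not used by the tests and its
# name is hypothetical, added only for documentation.
def _example_cp_chunk_ids(cp_rank: int, cp_size: int) -> Tuple[int, int]:
    """Chunk indices (out of 2 * cp_size) owned by `cp_rank` under the balanced split."""
    return cp_rank, 2 * cp_size - 1 - cp_rank


# For example, with cp_size = 2 a sequence split into chunks [c0, c1, c2, c3] is assigned as
# rank 0 -> (c0, c3) and rank 1 -> (c1, c2), which matches the two freqs slices concatenated
# in `_get_thd_freqs_on_this_cp_rank`.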


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
    cp_size: int,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
    if cp_size > 1:
        # Pad every sequence length to a multiple of 2 * cp_size so that each sequence can be
        # split into two balanced chunks per context-parallel rank.
        cu_seqlens_padded = [0]
        for i in range(1, len(cu_seqlens)):
            cu_seqlens_padded.append(
                cu_seqlens_padded[i - 1]
                + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2)
            )
    else:
        cu_seqlens_padded = cu_seqlens
    cu_seqlens_padded = torch.tensor(
        cu_seqlens_padded,
        dtype=torch.int32,
        device=device,
    )
    t = torch.rand(
        (cu_seqlens_padded[-1] // cp_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(cu_seqlens_padded[-1])

    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use
        # float32 for a more accurate comparison.
        output_unfused = apply_rotary_pos_emb_thd(
            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

        # fused
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
            fused=True,
            tensor_format="thd",
            cu_seqlens=cu_seqlens_padded,
            cp_size=cp_size,
            cp_rank=cp_rank,
        )
        loss_fused = loss_func(output_fused)
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        torch.testing.assert_close(grad_fused, grad_unfused)