# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)


def apply_rotary_pos_emb_thd(
    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for `thd` format.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
        with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    # Each packed sequence restarts at position 0, so every chunk uses the
    # leading `len(chunk)` rows of `freqs`; unsqueeze(1) adds a unit batch dim
    # so the chunk can go through the regular (sbhd) unfused path.
    return torch.cat(
        [
            apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)


def get_tol(dtype: torch.dtype) -> Dict:
    """Return `torch.testing.assert_close` tolerances appropriate for `dtype`."""
    if dtype == torch.bfloat16:
        return dict(atol=1e-2, rtol=1e-2)
    elif dtype == torch.float16:
        return dict(atol=1e-3, rtol=1e-3)
    return dict(atol=1e-5, rtol=1.3e-6)


# Gradient is a broadcasted scalar: the upstream grad reaching the RoPE output
# is a single value expanded to the output's shape (overlapping memory).
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2


# Gradient is a full tensor: multiplying by `ones` materializes a dense,
# non-overlapping upstream grad.
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
    return torch.sum(output * t)


def _forward_backward(
    rope_fn: Callable[[torch.Tensor], torch.Tensor],
    t: torch.Tensor,
    loss_func: Callable,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply `rope_fn` to leaf tensor `t`, backprop `loss_func`, return (output, grad).

    `t.grad` is cleared afterwards so the same leaf can be reused to evaluate
    another RoPE implementation without gradients accumulating.
    """
    output = rope_fn(t)
    loss_func(output).backward()
    grad = t.grad.detach().clone()
    t.grad = None
    return output, grad


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
) -> None:
    """Fused and unfused RoPE must agree (outputs and input grads) for sbhd/bshd."""
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    # `margin` makes the input shorter than the embedding table built below.
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        # Same logical shape, but a non-contiguous memory layout.
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    # unfused
    output_unfused, grad_unfused = _forward_backward(
        lambda x: apply_rotary_pos_emb(
            x, emb, tensor_format=tensor_format, fused=False
        ),
        t,
        loss_func,
    )

    # fused
    output_fused, grad_fused = _forward_backward(
        lambda x: apply_rotary_pos_emb(
            x,
            emb,
            tensor_format=tensor_format,
            fused=True,
        ),
        t,
        loss_func,
    )

    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
    assert output_fused.is_contiguous()


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
) -> None:
    """Fused `thd` RoPE must match the per-sequence baseline `apply_rotary_pos_emb_thd`."""
    device = torch.device("cuda:0")
    head_num = 64
    # 11 packed sequences of varying length, 2048 tokens in total.
    cu_seqlens = torch.tensor(
        [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
        dtype=torch.int32,
        device=device,
    )
    # Read the total token count once as a plain int instead of feeding a 0-d
    # CUDA tensor to shape / length arguments.
    total_tokens = int(cu_seqlens[-1])
    t = torch.rand(
        (total_tokens, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if transpose:
        # Same logical shape, but a non-contiguous memory layout.
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(total_tokens)

    # unfused
    output_unfused, grad_unfused = _forward_backward(
        lambda x: apply_rotary_pos_emb_thd(x, cu_seqlens, emb),
        t,
        loss_func,
    )

    # fused
    output_fused, grad_fused = _forward_backward(
        lambda x: apply_rotary_pos_emb(
            x, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
        ),
        t,
        loss_func,
    )

    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))