import pytest
import torch
from torch import nn

from lmdeploy.pytorch.kernels.fused_rotary_emb import fused_rotary_emb


class DummyRotaryEmbedding(nn.Module):
    """Eager (unfused) rotary embedding used as the reference implementation."""

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base**(2i / dim), one frequency per pair of dims.
        inv_freq = 1.0 / (self.base**(torch.arange(
            0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, position_ids, seq_len=None):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
            position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=x.dtype)
        sin = emb.sin().to(dtype=x.dtype)
        # backwards compatibility
        return cos, sin


class DummyLinearScalingRotaryEmbedding(DummyRotaryEmbedding):
    """Reference rotary embedding with linear position scaling."""

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None,
                 scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def forward(self, x, position_ids, seq_len=None):
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids, seq_len)
        return cos, sin


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=2):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class TestFusedRotaryEmb:

    @pytest.fixture
    def dtype(self):
        yield torch.float16

    @pytest.fixture
    def batch_size(self):
        yield 2

    @pytest.fixture
    def head_dim(self):
        yield 64

    @pytest.fixture
    def q_num_heads(self):
        yield 4

    @pytest.fixture
    def k_num_heads(self):
        yield 2

    @pytest.fixture
    def seq_len(self):
        yield 100

    @pytest.fixture
    def q(self, batch_size, seq_len, q_num_heads, head_dim, dtype):
        # Query tensor of shape (batch, seq_len, num_heads, head_dim).
        yield torch.rand(batch_size,
                         seq_len,
                         q_num_heads,
                         head_dim,
                         dtype=dtype).to('cuda')

    @pytest.fixture
    def k(self, batch_size, seq_len, k_num_heads, head_dim, dtype):
        # Key tensor with fewer heads than the query (GQA-style layout).
        yield torch.rand(batch_size,
                         seq_len,
                         k_num_heads,
                         head_dim,
                         dtype=dtype).to('cuda')

    @pytest.fixture
    def position_ids(self, batch_size, seq_len):
        yield torch.randint(0, seq_len + 100, (batch_size, seq_len)).cuda()

    @pytest.fixture
    def rotary_emb(self, head_dim):
        yield DummyLinearScalingRotaryEmbedding(head_dim,
                                                scaling_factor=1.0).to('cuda')

    @pytest.fixture
    def gt(self, q, k, position_ids, rotary_emb):
        # Ground truth from the eager reference implementation.
        with torch.inference_mode():
            cos, sin = rotary_emb(q, position_ids)
            yield apply_rotary_pos_emb(q,
                                       k,
                                       cos,
                                       sin,
                                       position_ids=position_ids)

    def test_fused_rotary_emb(self, q, k, position_ids, rotary_emb, gt):
        # Compare the fused Triton kernel against the eager reference.
        inv_freq = rotary_emb.inv_freq
        scaling_factor = rotary_emb.scaling_factor
        with torch.inference_mode():
            outq, outk = fused_rotary_emb(q,
                                          k,
                                          position_ids,
                                          inv_freq,
                                          scaling_factor=scaling_factor)
        gtq, gtk = gt
        torch.testing.assert_close(outq, gtq, atol=1e-3, rtol=1e-5)
        torch.testing.assert_close(outk, gtk, atol=1e-3, rtol=1e-5)
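

# The block below is an optional convenience sketch, not part of the original
# suite: it lets this file be run directly with `python <this file>` instead of
# through pytest collection. It assumes pytest and a CUDA device are available,
# consistent with the fixtures above.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])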