from typing import NamedTuple

import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.utils import (
    cpu_has_amx_support,
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
    is_xpu,
)

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for the given configuration."""
    torch.manual_seed(42)

    # Create 2D positions (3, num_tokens) for the multimodal (M-RoPE) case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(
        num_tokens, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        num_tokens, num_kv_heads * head_size, dtype=dtype, device=device
    )

    return positions, query, key


class MRoPETestInfo(NamedTuple):
    model_name: str
    atol: float = 1e-2
    rtol: float = 1.6e-2
    marks: list[pytest.MarkDecorator] = []


TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version

MODELS_TO_TEST = [
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
]

num_tokens_list = [11, 8192]


@pytest.mark.skipif(not _is_cuda, reason="Skipping CUDA-only tests.")
@pytest.mark.parametrize(
    "model_info, model_name",
    [
        pytest.param(test_config, test_config.model_name, marks=test_config.marks)
        for test_config in MODELS_TO_TEST
    ],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(
    model_name: str,
    model_info: MRoPETestInfo,
    tp_size: int,
    dtype: torch.dtype,
    num_tokens: int,
):
    atol = model_info.atol
    rtol = model_info.rtol

    config = AutoConfig.from_pretrained(model_name)
    config = config.get_text_config()

    # Get the attention/rotary parameters from the model config
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // total_num_heads
    )
    is_neox_style = True
    rope_theta = config.rope_theta
    max_position = config.max_position_embeddings
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    rotary_dim = int(head_dim * partial_rotary_factor)

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=rotary_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=config.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    # Create query/key and rotary position embedding input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Reference path: pure-PyTorch implementation
    query_native, key_native = mrope_helper_class.forward_native(
        positions,
        query.clone(),
        key.clone(),
    )

    # Optimized device path
    query_cuda, key_cuda = mrope_helper_class.forward(
        positions,
        query.clone(),
        key.clone(),
    )

    # The two paths must agree within the per-model tolerances
    torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
    torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
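

# Optional direct entry point; an assumed convenience for local runs, not part of
# the original test. It simply forwards to pytest so the file can also be run as
# `python test_mrope.py` in addition to the usual `pytest` invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])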