Unverified commit cea91a32, authored by Isotr0py, committed by GitHub
Browse files

[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent a684c012
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
import pytest import pytest
import torch import torch
from packaging.version import Version
from transformers import AutoConfig from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, ...@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size: int, max_position_embeddings: int, head_size: int, max_position_embeddings: int,
dtype: torch.dtype, device: torch.device): dtype: torch.dtype, device: torch.device):
"""Generate test data for given configuration.""" """Generate test data for given configuration."""
current_platform.seed_everything(42)
# Create 2D positions (3, num_tokens) for multimodal case # Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(0, positions = torch.randint(0,
max_position_embeddings // 4, (3, num_tokens), max_position_embeddings // 4, (3, num_tokens),
...@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, ...@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
return positions, query, key return positions, query, key
def unroll_model_tp_dict(model_tp_dict): class MRoPETestInfo(NamedTuple):
return [(model_name, tp_size) model_name: str
for model_name, tp_sizes in model_tp_dict.items() # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
for tp_size in tp_sizes] atol: float = 1e-2
rtol: float = 1.6e-2
marks: list[pytest.MarkDecorator] = []
model_tp_dict = { TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"Qwen/Qwen2-VL-72B-Instruct": [1, 2],
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2],
}
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 MODELS_TO_TEST = [
dtype_atol_rtol_list = [ MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
[torch.bfloat16, 1e-2, 1.6e-2], MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
] ]
num_tokens_list = [11, 8192] num_tokens_list = [11, 8192]
...@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192] ...@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
@pytest.mark.skipif(not current_platform.is_cuda_alike(), @pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.") reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_name, tp_size", @pytest.mark.parametrize("model_info, model_name", [
unroll_model_tp_dict(model_tp_dict)) pytest.param(test_config, test_config.model_name, marks=test_config.marks)
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list) @pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name) config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
# get the model config # get the model config
total_num_kv_heads = config.num_key_value_heads total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size) num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta rope_theta = config.rope_theta
...@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): ...@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
@pytest.mark.skipif(not current_platform.is_cuda_alike(), @pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.") reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize( @pytest.mark.parametrize("model_info, model_name", [
"model_name, tp_size", pytest.param(test_config, test_config.model_name, marks=test_config.marks)
unroll_model_tp_dict({ for test_config in MODELS_TO_TEST
"Qwen/Qwen2-VL-7B-Instruct": [1, 2], ])
"zai-org/GLM-4.1V-9B-Thinking": [1, 2] @pytest.mark.parametrize("tp_size", [1, 2])
})) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) @pytest.mark.parametrize("num_tokens", num_tokens_list)
@pytest.mark.parametrize("num_tokens", [4]) def test_mrope_torch_compile_tracing(model_name: str,
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, model_info: MRoPETestInfo, tp_size: int,
num_tokens): dtype: torch.dtype, num_tokens: int):
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name) config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
# get the model config # get the model config
total_num_kv_heads = config.num_key_value_heads total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size) num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta rope_theta = config.rope_theta
max_position = config.max_position_embeddings max_position = config.max_position_embeddings
......
...@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch ...@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
@triton.jit @triton.jit
def _triton_qwen2vl_mrope_forward( def _triton_mrope_forward(
q_ptr, q_ptr,
k_ptr, k_ptr,
cos, cos,
...@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward( ...@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd: tl.constexpr, pad_hd: tl.constexpr,
mrope_section_t: tl.constexpr, mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr, mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
): ):
# Adapted from # Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flatten input tensors from vllm # This version supports flatten input tensors from vllm
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2) # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim) # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
pid = tl.program_id(0) pid = tl.program_id(0)
# locate start address # locate start address
q_ptr = q_ptr + pid * (n_qh * hd) q_ptr = q_ptr + pid * (n_qh * hd)
...@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward( ...@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
# #################################################################### # ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2) # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
t_end = mrope_section_t
h_end = t_end + mrope_section_h
# Updated stride calculation for half head_dim # Updated stride calculation for half head_dim
half_rd = rd // 2 half_rd = rd // 2
t_cos = cos + pid * half_rd t_cos = cos + pid * half_rd
...@@ -61,7 +60,16 @@ def _triton_qwen2vl_mrope_forward( ...@@ -61,7 +60,16 @@ def _triton_qwen2vl_mrope_forward(
# Updated offsets for half head_dim # Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2) cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) &
(cos_offsets <= 3 * mrope_section_h))
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
t_mask = ~(h_mask | w_mask)
else:
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_mask = cos_offsets < mrope_section_t
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
...@@ -131,6 +139,7 @@ def triton_mrope( ...@@ -131,6 +139,7 @@ def triton_mrope(
mrope_section: list[int], mrope_section: list[int],
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Qwen2VL mrope kernel. """Qwen2VL mrope kernel.
...@@ -158,7 +167,7 @@ def triton_mrope( ...@@ -158,7 +167,7 @@ def triton_mrope(
cos = cos.contiguous() cos = cos.contiguous()
sin = sin.contiguous() sin = sin.contiguous()
_triton_qwen2vl_mrope_forward[(n_row, )]( _triton_mrope_forward[(n_row, )](
q, q,
k, k,
cos, cos,
...@@ -173,6 +182,8 @@ def triton_mrope( ...@@ -173,6 +182,8 @@ def triton_mrope(
pad_hd, pad_hd,
mrope_section[0], mrope_section[0],
mrope_section[1], mrope_section[1],
mrope_section[2],
mrope_interleaved,
) )
return q, k return q, k
...@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[list[int]] = None, mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False, mrope_interleaved: bool = False,
) -> None: ) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of # In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get # the input video. We enlarge max_position_embeddings to 4 times to get
...@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
assert positions.ndim == 1 or positions.ndim == 2 assert positions.ndim == 1 or positions.ndim == 2
assert key is not None assert key is not None
if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)
num_tokens = positions.shape[-1] num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions] cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
...@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_section, self.mrope_section,
self.head_size, self.head_size,
self.rotary_dim, self.rotary_dim,
self.mrope_interleaved,
) )
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment