Unverified commit 61305291 authored by yhyang201, committed by GitHub

Quick Fix: fix Qwen3-VL launch failure caused by MRotaryEmbedding arg (#10985)

parent a9ce2bcb
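
For context, here is a minimal, self-contained sketch of the failure mode this commit addresses. The class names are illustrative stand-ins (not sglang's real classes); only the `fused_set_kv_buffer_arg` keyword comes from the diff. The Qwen3-MoE attention layer passes that keyword to its rotary embedding, and a rope implementation whose PyTorch-native forward does not accept it fails at startup with a TypeError.

```python
# Hypothetical stand-in classes; only the keyword name comes from the actual diff.
class RopeWithoutFusedArg:
    def forward_native(self, positions, query, key):
        return query, key


class RopeWithFusedArg:
    def forward_native(self, positions, query, key, fused_set_kv_buffer_arg=None):
        # Mirrors the guard added in this commit: the fused KV write is rejected,
        # but a plain call that passes None still works.
        assert fused_set_kv_buffer_arg is None
        return query, key


try:
    RopeWithoutFusedArg().forward_native(None, None, None, fused_set_kv_buffer_arg=None)
except TypeError as exc:
    print("launch failure:", exc)  # unexpected keyword argument 'fused_set_kv_buffer_arg'

# With the defaulted keyword in place, the same call site succeeds.
RopeWithFusedArg().forward_native(None, None, None, fused_set_kv_buffer_arg=None)
```
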
@@ -1065,6 +1065,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().
@@ -1075,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
         num_tokens = positions.shape[-1]
...
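
Reconstructed from the hunks above, the patched MRotaryEmbedding method looks roughly like the sketch below after the change: the signature gains a defaulted keyword, and a guard rejects any non-None fused KV-buffer argument. The method name (`forward`), the class name of the sketch, and the `FusedSetKVBufferArg` stub are assumptions made so the example runs standalone; the rotary math itself is elided.

```python
from typing import Optional, Tuple

import torch


class FusedSetKVBufferArg:  # stand-in so this sketch runs without sglang
    pass


class MRotaryEmbeddingSketch:
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        # New guard from this commit: accept the keyword, but only as None.
        assert (
            fused_set_kv_buffer_arg is None
        ), "save kv cache is not supported for MRotaryEmbedding."
        assert positions.ndim == 1 or positions.ndim == 2
        num_tokens = positions.shape[-1]
        # ... multimodal rotary application unchanged and elided (num_tokens used there) ...
        return query, key
```
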
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -358,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
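
The new `compatible_with_fused_kv_buffer` flag is computed once at construction time. A minimal sketch of the idea is below, with stand-in classes and the committed `False if isinstance(...) else True` rewritten as the equivalent `not isinstance(...)`:

```python
class MRotaryEmbedding:  # stand-in for sglang's multimodal rope class
    pass


class DefaultRope:  # stand-in for whatever get_rope() otherwise returns
    pass


def compatible_with_fused_kv_buffer(rotary_emb) -> bool:
    # Equivalent to the committed `False if isinstance(...) else True`:
    # only non-MRotaryEmbedding rope layers may take the fused KV-buffer path.
    return not isinstance(rotary_emb, MRotaryEmbedding)


assert compatible_with_fused_kv_buffer(DefaultRope()) is True
assert compatible_with_fused_kv_buffer(MRotaryEmbedding()) is False
```
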
@@ -427,6 +431,7 @@ class Qwen3MoeAttention(nn.Module):
                     forward_batch=forward_batch,
                 )
                 if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
                 else None
             ),
         )
@@ -439,7 +444,10 @@ class Qwen3MoeAttention(nn.Module):
             return hidden_states
         attn_output = self.attn(
             *inner_state,
-            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
         )
         output, _ = self.o_proj(attn_output)
         return output
...
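
Taken together, the last two hunks gate both call sites on the same condition: the fused KV-buffer argument is built only when the batch enables fusion and the rope layer supports it, and otherwise the attention layer saves the KV cache itself. The sketch below assumes a boolean-returning `enable_fused_set_kv_buffer` helper (its real body is not shown in the diff) and uses a hypothetical `kv_write_plan` function to make the decision explicit:

```python
def enable_fused_set_kv_buffer(forward_batch) -> bool:
    # Stand-in body: the real helper inspects the forward batch / attention backend.
    return bool(getattr(forward_batch, "fused_set_kv_buffer", False))


def kv_write_plan(forward_batch, compatible_with_fused_kv_buffer: bool) -> dict:
    use_fused = (
        enable_fused_set_kv_buffer(forward_batch) and compatible_with_fused_kv_buffer
    )
    return {
        # Build a FusedSetKVBufferArg for the rope call only on the fused path.
        "build_fused_arg": use_fused,
        # Otherwise RadixAttention saves the KV cache itself, which is the
        # pre-existing path that MRotaryEmbedding (Qwen3-VL) requires.
        "save_kv_cache": not use_fused,
    }
```

With this gating in place, a Qwen3-VL batch that enables the fused KV write no longer reaches MRotaryEmbedding with a non-None `fused_set_kv_buffer_arg`, which is what caused the launch failure.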