Unverified commit 61305291 authored by yhyang201, committed by GitHub

Quick Fix: fix Qwen3-VL launch failure caused by MRotaryEmbedding arg (#10985)

parent a9ce2bcb
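
In short: Qwen3-VL uses MRotaryEmbedding for its rotary positions, but the shared Qwen3-MoE attention path that Qwen3-VL builds on passes a fused_set_kv_buffer_arg keyword to every rotary embedding it calls. MRotaryEmbedding's PyTorch-native forward did not declare that keyword, so launching Qwen3-VL failed with a TypeError. The sketch below is a minimal, hypothetical reconstruction of that failure mode; the class names and call shapes are simplified stand-ins, not the actual sglang code.

    # Minimal, hypothetical sketch of the failure mode (simplified stand-ins,
    # not the real sglang implementation).

    class RotaryEmbedding:
        # Base class: already accepts the optional fused-KV-buffer argument.
        def forward(self, positions, query, key, fused_set_kv_buffer_arg=None):
            return query, key


    class MRotaryEmbeddingBeforeFix(RotaryEmbedding):
        # Multimodal rotary embedding before this commit: keyword missing.
        def forward(self, positions, query, key):
            return query, key


    def attention_rope_call(rope, positions, q, k):
        # The shared attention path always passes the keyword, even when it is None.
        return rope.forward(positions, q, k, fused_set_kv_buffer_arg=None)


    attention_rope_call(RotaryEmbedding(), [0], "q", "k")  # works
    try:
        attention_rope_call(MRotaryEmbeddingBeforeFix(), [0], "q", "k")
    except TypeError as exc:
        print(exc)  # forward() got an unexpected keyword argument 'fused_set_kv_buffer_arg'

The commit fixes this from both sides: MRotaryEmbedding now accepts the keyword (but asserts it is None), and the attention layer stops requesting the fused KV-buffer write whenever its rotary embedding is an MRotaryEmbedding.
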
@@ -1065,6 +1065,7 @@ class MRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().
@@ -1075,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert (
fused_set_kv_buffer_arg is None
), "save kv cache is not supported for MRotaryEmbedding."
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
......
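
The keyword is accepted purely for call-site compatibility with the base RotaryEmbedding interface; the assert documents that the fused KV-cache write path is not implemented for the multimodal rotary embedding, so callers have to fall back to the regular save-KV-cache path, which is exactly what the model-side change below does.
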
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -358,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.compatible_with_fused_kv_buffer = (
False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
@@ -427,6 +431,7 @@ class Qwen3MoeAttention(nn.Module):
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
else None
),
)
@@ -439,7 +444,10 @@ class Qwen3MoeAttention(nn.Module):
return hidden_states
attn_output = self.attn(
*inner_state,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
save_kv_cache=not (
enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
),
)
output, _ = self.o_proj(attn_output)
return output
......
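
Putting the model-side pieces together, the new compatible_with_fused_kv_buffer flag gates both decisions: whether to build a FusedSetKVBufferArg for the RoPE call and whether the attention op must save the KV cache itself. Below is a hedged sketch of the resulting gating; kv_write_plan is a hypothetical helper used only for illustration, and the local MRotaryEmbedding class is a stub standing in for the sglang type.

    # Hedged sketch of the gating added in Qwen3MoeAttention (illustration only).

    class MRotaryEmbedding:
        # Stub for sglang.srt.layers.rotary_embedding.MRotaryEmbedding.
        pass


    def kv_write_plan(rotary_emb, fused_kv_enabled):
        compatible = not isinstance(rotary_emb, MRotaryEmbedding)
        use_fused = fused_kv_enabled and compatible
        return {
            # Pass a real FusedSetKVBufferArg to the RoPE kernel only on the fused path.
            "pass_fused_set_kv_buffer_arg": use_fused,
            # Otherwise the attention op itself must write the KV cache.
            "save_kv_cache": not use_fused,
        }


    print(kv_write_plan(MRotaryEmbedding(), fused_kv_enabled=True))
    # {'pass_fused_set_kv_buffer_arg': False, 'save_kv_cache': True}
    print(kv_write_plan(object(), fused_kv_enabled=True))
    # {'pass_fused_set_kv_buffer_arg': True, 'save_kv_cache': False}
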