Unverified Commit 9704a5c3 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

Disable dual stream execution of input projection for Qwen3 (#38152)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 74056039
...@@ -221,12 +221,8 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): ...@@ -221,12 +221,8 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
ba, _ = self.in_proj_ba(hidden_states) ba, _ = self.in_proj_ba(hidden_states)
z, _ = self.in_proj_z(hidden_states) z, _ = self.in_proj_z(hidden_states)
else: else:
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj( mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
hidden_states, ba, _ = self.in_proj_ba(hidden_states)
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
......
...@@ -81,11 +81,7 @@ from vllm.platforms import current_platform ...@@ -81,11 +81,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
aux_stream,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
...@@ -421,12 +417,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -421,12 +417,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.act = ACT2FN[config.hidden_act] self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix self.prefix = prefix
self.aux_stream = aux_stream()
self.events = (
[torch.cuda.Event(), torch.cuda.Event()]
if current_platform.is_cuda_alike()
else [None, None]
)
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
...@@ -659,12 +649,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -659,12 +649,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# ============================================================ # ============================================================
# Part 1: Input Projection # Part 1: Input Projection
# ============================================================ # ============================================================
projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj( projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
hidden_states, projected_states_ba, _ = self.in_proj_ba(hidden_states)
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
query, key, value, z, b, a = self.fix_query_key_value_ordering( query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba projected_states_qkvz, projected_states_ba
) )
...@@ -804,18 +790,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -804,18 +790,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
def _forward_in_proj(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
lambda: self.in_proj_qkvz(hidden_states)[0],
lambda: self.in_proj_ba(hidden_states)[0],
self.events[0],
self.events[1],
self.aux_stream,
)
return projected_states_qkvz, projected_states_ba
def _forward_core( def _forward_core(
self, self,
mixed_qkv: torch.Tensor, mixed_qkv: torch.Tensor,
...@@ -1697,32 +1671,6 @@ class Qwen3NextForCausalLM( ...@@ -1697,32 +1671,6 @@ class Qwen3NextForCausalLM(
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
def gdn_in_proj(
hidden_states: torch.Tensor,
qkvz_output_size: int,
ba_output_size: int,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Custom op for the input projection.
"""
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self._forward_in_proj(hidden_states)
def gdn_in_proj_fake(
hidden_states: torch.Tensor,
qkvz_output_size: int,
ba_output_size: int,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile."""
return hidden_states.new_empty(
hidden_states.shape[0], qkvz_output_size
), hidden_states.new_empty(hidden_states.shape[0], ba_output_size)
def gdn_attention_core( def gdn_attention_core(
mixed_qkv: torch.Tensor, mixed_qkv: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
...@@ -1756,12 +1704,6 @@ def gdn_attention_core_fake( ...@@ -1756,12 +1704,6 @@ def gdn_attention_core_fake(
return return
direct_register_custom_op(
op_name="gdn_in_proj",
op_func=gdn_in_proj,
fake_impl=gdn_in_proj_fake,
)
direct_register_custom_op( direct_register_custom_op(
op_name="gdn_attention_core", op_name="gdn_attention_core",
op_func=gdn_attention_core, op_func=gdn_attention_core,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment