Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1740006
Unverified
Commit
f1740006
authored
Mar 17, 2026
by
Xin Yang
Committed by
GitHub
Mar 18, 2026
Browse files
[Perf] Enable dual stream execution of input projection for Qwen3 (#36795)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
58cde5c0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
115 additions
and
5 deletions
+115
-5
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+6
-2
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+61
-3
vllm/utils/multi_stream_utils.py
vllm/utils/multi_stream_utils.py
+48
-0
No files found.
vllm/model_executor/models/qwen3_5.py
View file @
f1740006
...
@@ -180,12 +180,16 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -180,12 +180,16 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
mixed_qkvz
,
ba
=
torch
.
ops
.
vllm
.
gdn_in_proj
(
hidden_states
,
self
.
in_proj_qkvz
.
weight
.
shape
[
0
],
self
.
in_proj_ba
.
weight
.
shape
[
0
],
self
.
prefix
,
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
ba
,
_
=
self
.
in_proj_ba
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
=
b
.
contiguous
()
b
=
b
.
contiguous
()
...
...
vllm/model_executor/models/qwen3_next.py
View file @
f1740006
...
@@ -82,7 +82,11 @@ from vllm.platforms import current_platform
...
@@ -82,7 +82,11 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
Qwen3NextConfig
from
vllm.transformers_utils.configs
import
Qwen3NextConfig
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.multi_stream_utils
import
maybe_execute_in_parallel
from
vllm.utils.torch_utils
import
(
aux_stream
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
...
@@ -419,6 +423,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -419,6 +423,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
aux_stream
=
aux_stream
()
self
.
events
=
(
[
torch
.
cuda
.
Event
(),
torch
.
cuda
.
Event
()]
if
current_platform
.
is_cuda
()
else
[
None
,
None
]
)
self
.
config
=
config
self
.
config
=
config
self
.
model_config
=
model_config
self
.
model_config
=
model_config
...
@@ -647,8 +657,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -647,8 +657,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
projected_states_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
projected_states_qkvz
,
projected_states_ba
=
torch
.
ops
.
vllm
.
gdn_in_proj
(
projected_states_ba
,
_
=
self
.
in_proj_ba
(
hidden_states
)
hidden_states
,
self
.
in_proj_qkvz
.
weight
.
shape
[
0
],
self
.
in_proj_ba
.
weight
.
shape
[
0
],
self
.
prefix
,
)
query
,
key
,
value
,
z
,
b
,
a
=
self
.
fix_query_key_value_ordering
(
query
,
key
,
value
,
z
,
b
,
a
=
self
.
fix_query_key_value_ordering
(
projected_states_qkvz
,
projected_states_ba
projected_states_qkvz
,
projected_states_ba
)
)
...
@@ -783,6 +797,18 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -783,6 +797,18 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
torch
.
accelerator
.
empty_cache
()
torch
.
accelerator
.
empty_cache
()
def _forward_in_proj(
    self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the qkvz and ba input projections of ``hidden_states``.

    Both projections are handed to ``maybe_execute_in_parallel`` together
    with this layer's two events and its auxiliary stream, so that helper
    decides whether they run on separate streams or back to back.

    Args:
        hidden_states: Input passed unchanged to both projection layers.

    Returns:
        Tuple ``(projected_states_qkvz, projected_states_ba)``: the first
        element of what ``self.in_proj_qkvz`` and ``self.in_proj_ba``
        return, respectively.
    """

    def _project_qkvz() -> torch.Tensor:
        # Only element 0 of the layer's return value is kept.
        return self.in_proj_qkvz(hidden_states)[0]

    def _project_ba() -> torch.Tensor:
        # Only element 0 of the layer's return value is kept.
        return self.in_proj_ba(hidden_states)[0]

    first_event = self.events[0]
    second_event = self.events[1]
    projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
        _project_qkvz,
        _project_ba,
        first_event,
        second_event,
        self.aux_stream,
    )
    return projected_states_qkvz, projected_states_ba
def
_forward_core
(
def
_forward_core
(
self
,
self
,
mixed_qkv
:
torch
.
Tensor
,
mixed_qkv
:
torch
.
Tensor
,
...
@@ -1670,6 +1696,32 @@ class Qwen3NextForCausalLM(
...
@@ -1670,6 +1696,32 @@ class Qwen3NextForCausalLM(
return
self
.
model
.
get_expert_mapping
()
return
self
.
model
.
get_expert_mapping
()
def gdn_in_proj(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom op for the input projection.

    Resolves the layer stored under ``layer_name`` in the current forward
    context's ``no_compile_layers`` mapping and delegates to that layer's
    ``_forward_in_proj``.

    Args:
        hidden_states: Input forwarded to the layer's ``_forward_in_proj``.
        qkvz_output_size: Not read by this implementation.
        ba_output_size: Not read by this implementation.
        layer_name: Key used to look the layer up in ``no_compile_layers``.

    Returns:
        Whatever the layer's ``_forward_in_proj`` returns (annotated as a
        pair of tensors).
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    return layer._forward_in_proj(hidden_states)
def gdn_in_proj_fake(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile.

    Produces uninitialized placeholder outputs with the expected shapes
    without running any projection.

    Args:
        hidden_states: Tensor whose first dimension sets the row count of
            both outputs; ``new_empty`` is called on it to allocate them.
        qkvz_output_size: Size of the second dimension of the first output.
        ba_output_size: Size of the second dimension of the second output.
        layer_name: Not read by this implementation.

    Returns:
        Tuple of two uninitialized tensors shaped
        ``(hidden_states.shape[0], qkvz_output_size)`` and
        ``(hidden_states.shape[0], ba_output_size)``.
    """
    num_rows = hidden_states.shape[0]
    qkvz_placeholder = hidden_states.new_empty(num_rows, qkvz_output_size)
    ba_placeholder = hidden_states.new_empty(num_rows, ba_output_size)
    return qkvz_placeholder, ba_placeholder
def
gdn_attention_core
(
def
gdn_attention_core
(
mixed_qkv
:
torch
.
Tensor
,
mixed_qkv
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
@@ -1703,6 +1755,12 @@ def gdn_attention_core_fake(
...
@@ -1703,6 +1755,12 @@ def gdn_attention_core_fake(
return
return
# Register the input projection under the op name "gdn_in_proj", pairing the
# real implementation with the fake one.
direct_register_custom_op(
    op_name="gdn_in_proj",
    op_func=gdn_in_proj,
    fake_impl=gdn_in_proj_fake,
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"gdn_attention_core"
,
op_name
=
"gdn_attention_core"
,
op_func
=
gdn_attention_core
,
op_func
=
gdn_attention_core
,
...
...
vllm/utils/multi_stream_utils.py
0 → 100644
View file @
f1740006
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
Any
import
torch
def
maybe_execute_in_parallel
(
fn0
:
Callable
[[],
Any
],
fn1
:
Callable
[[],
Any
],
event0
:
torch
.
cuda
.
Event
,
event1
:
torch
.
cuda
.
Event
,
aux_stream
:
torch
.
cuda
.
Stream
|
None
=
None
,
)
->
tuple
[
Any
,
Any
]:
"""Run two functions potentially in parallel on separate CUDA streams.
When aux_stream is provided, fn0 runs on the current (default) stream and
fn1 runs on aux_stream, synchronized via CUDA events. When aux_stream is
None, both functions execute sequentially on the current stream.
This design follows TensorRT-LLM's maybe_execute_in_parallel pattern
(tensorrt_llm/_torch/modules/multi_stream_utils.py).
Args:
fn0: Callable for the default stream.
fn1: Callable for the auxiliary stream.
event0: CUDA event recorded before fn0 so aux_stream can wait.
event1: CUDA event recorded after fn1 so default stream can wait.
aux_stream: The second CUDA stream for fn1.
Multi-stream is disabled when aux_stream is None.
Returns:
Tuple of (fn0_result, fn1_result).
"""
if
aux_stream
is
not
None
:
event0
.
record
()
result0
=
fn0
()
with
torch
.
cuda
.
stream
(
aux_stream
):
event0
.
wait
()
result1
=
fn1
()
event1
.
record
()
event1
.
wait
()
else
:
result0
=
fn0
()
result1
=
fn1
()
return
(
result0
,
result1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment