Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
285bab47
Unverified
Commit
285bab47
authored
Feb 10, 2026
by
Jiangyun Zhu
Committed by
GitHub
Feb 09, 2026
Browse files
[Kernel] use flashinfer for gdn prefill (#32846)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
995bbf38
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
115 additions
and
2 deletions
+115
-2
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+115
-2
No files found.
vllm/model_executor/models/qwen3_next.py
View file @
285bab47
...
@@ -28,11 +28,15 @@ from vllm.distributed import (
...
@@ -28,11 +28,15 @@ from vllm.distributed import (
)
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fla.ops
import
(
from
vllm.model_executor.layers.fla.ops
import
(
chunk_gated_delta_rule
,
chunk_gated_delta_rule
as
fla_chunk_gated_delta_rule
,
)
from
vllm.model_executor.layers.fla.ops
import
(
fused_recurrent_gated_delta_rule
,
fused_recurrent_gated_delta_rule
,
)
)
from
vllm.model_executor.layers.fla.ops.chunk
import
l2norm_fwd
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3NextRMSNorm
,
GemmaRMSNorm
as
Qwen3NextRMSNorm
,
...
@@ -101,6 +105,113 @@ logger = init_logger(__name__)
...
@@ -101,6 +105,113 @@ logger = init_logger(__name__)
KVCache
=
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
def fi_chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = True,
):
    """Run the chunked gated delta rule through FlashInfer's GDN prefill kernel.

    The signature mirrors the FLA ``chunk_gated_delta_rule`` entry point so the
    two can be swapped by the caller.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        g: Gate tensor; it is cast to float32 and passed through ``torch.exp``
            before being handed to FlashInfer.
        beta: Beta tensor; cast to float32 for the kernel.
        initial_state: Initial recurrent state; cast to float32 for the kernel.
        output_final_state: Forwarded unchanged to the FlashInfer kernel.
        cu_seqlens: Forwarded unchanged to the FlashInfer kernel.
        head_first: Accepted for signature compatibility only; this function
            never reads it.
        use_qk_l2norm_in_kernel: When true, ``q`` and ``k`` are normalised with
            ``l2norm_fwd`` here, before the kernel call, rather than inside it.

    Returns:
        Whatever ``flashinfer.gdn_prefill.chunk_gated_delta_rule`` returns.
    """
    # Imported lazily so flashinfer is only required when this path is taken.
    from flashinfer.gdn_prefill import chunk_gated_delta_rule as _fi_kernel

    if use_qk_l2norm_in_kernel:
        q, k = l2norm_fwd(q), l2norm_fwd(k)

    # Drop dim 0 where it has size 1 and force a contiguous layout.
    # NOTE(review): presumably every input carries a leading batch dim of
    # size 1 -- confirm against the caller.
    q, k, v, g, beta = (
        tensor.squeeze(0).contiguous() for tensor in (q, k, v, g, beta)
    )

    state_fp32 = initial_state.to(torch.float32)
    gate_fp32 = g.to(torch.float32)
    beta_fp32 = beta.to(torch.float32)

    # NOTE(review): exp(g) suggests the caller supplies a log-space gate while
    # FlashInfer expects the decay factor itself -- confirm with FlashInfer docs.
    return _fi_kernel(
        q=q,
        k=k,
        v=v,
        g=torch.exp(gate_fp32),
        beta=beta_fp32,
        initial_state=state_fp32,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
    """Custom op that picks a chunked gated-delta-rule implementation.

    At construction time the op selects either the FlashInfer-backed
    ``forward_cuda`` or the FLA-backed ``forward_native`` and stores the
    choice in ``self._forward_method``.
    """

    def __init__(self) -> None:
        """Choose the forward implementation from the current platform."""
        super().__init__()

        # FlashInfer is used only when the platform is CUDA and reports
        # device capability 90; everything else takes the FLA path.
        use_flashinfer = (
            current_platform.is_cuda()
            and current_platform.is_device_capability(90)
        )
        if not use_flashinfer:
            self._forward_method = self.forward_native
            return

        logger.info_once(
            "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
        )
        self._forward_method = self.forward_cuda

    def forward_cuda(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Forward every argument to ``fi_chunk_gated_delta_rule`` (FlashInfer)."""
        return fi_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

    def forward_native(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        head_first: bool = False,
        use_qk_l2norm_in_kernel: bool = True,
    ):
        """Forward every argument to ``fla_chunk_gated_delta_rule`` (FLA)."""
        return fla_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
class
Qwen3NextSparseMoeBlock
(
nn
.
Module
):
class
Qwen3NextSparseMoeBlock
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -362,6 +473,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -362,6 +473,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix
=
f
"
{
prefix
}
.out_proj"
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
)
self
.
chunk_gated_delta_rule
=
ChunkGatedDeltaRule
()
compilation_config
=
get_current_vllm_config
().
compilation_config
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
...
@@ -647,7 +760,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -647,7 +760,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
(
(
core_attn_out_non_spec
,
core_attn_out_non_spec
,
last_recurrent_state
,
last_recurrent_state
,
)
=
chunk_gated_delta_rule
(
)
=
self
.
chunk_gated_delta_rule
(
q
=
query_non_spec
,
q
=
query_non_spec
,
k
=
key_non_spec
,
k
=
key_non_spec
,
v
=
value_non_spec
,
v
=
value_non_spec
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment