Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c18f88c6
Unverified
Commit
c18f88c6
authored
Nov 06, 2025
by
Jiangyun Zhu
Committed by
GitHub
Nov 05, 2025
Browse files
[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent
6fd0df81
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 30 additions and 9 deletions
+30
-9
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+30
-9
No files found.
vllm/model_executor/models/qwen3_next.py
View file @
c18f88c6
...
...
@@ -551,10 +551,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec
)
beta
=
b
.
sigmoid
()
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
g
=
fused_gdn_gating
(
self
.
A_log
,
a
,
self
.
dt_bias
)
g
,
beta
=
map
(
lambda
x
:
rearrange
(
x
,
"l d -> 1 l d"
),
(
g
,
beta
))
g
,
beta
=
fused_gdn_gating
(
self
.
A_log
,
a
,
b
,
self
.
dt_bias
)
if
spec_sequence_masks
is
not
None
:
if
attn_metadata
.
num_prefills
==
0
and
attn_metadata
.
num_decodes
==
0
:
...
...
@@ -1289,12 +1286,13 @@ direct_register_custom_op(
)
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@
triton
.
jit
def
fused_gdn_gating_kernel
(
g
,
beta_output
,
A_log
,
a
,
b
,
dt_bias
,
seq_len
,
NUM_HEADS
:
tl
.
constexpr
,
...
...
@@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
mask
=
head_off
<
NUM_HEADS
blk_A_log
=
tl
.
load
(
A_log
+
head_off
,
mask
=
mask
)
blk_a
=
tl
.
load
(
a
+
off
,
mask
=
mask
)
blk_b
=
tl
.
load
(
b
+
off
,
mask
=
mask
)
blk_bias
=
tl
.
load
(
dt_bias
+
head_off
,
mask
=
mask
)
# If the model is loaded in fp16, without the .float() here, A might be -inf
x
=
blk_a
.
to
(
tl
.
float32
)
+
blk_bias
.
to
(
tl
.
float32
)
...
...
@@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
)
blk_g
=
-
tl
.
exp
(
blk_A_log
.
to
(
tl
.
float32
))
*
softplus_x
tl
.
store
(
g
+
off
,
blk_g
.
to
(
g
.
dtype
.
element_ty
),
mask
=
mask
)
# compute beta_output = sigmoid(b)
blk_beta
=
1.0
/
(
1.0
+
tl
.
exp
(
-
blk_b
.
to
(
tl
.
float32
)))
tl
.
store
(
beta_output
+
off
,
blk_beta
.
to
(
beta_output
.
dtype
.
element_ty
),
mask
=
mask
)
def
fused_gdn_gating
(
A_log
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
dt_bias
:
torch
.
Tensor
,
beta
:
float
=
1.0
,
threshold
:
float
=
20.0
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Fused computation of g and beta for Gated Delta Net.
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
beta_output = b.sigmoid()
TODO maybe use torch.compile to replace this triton kernel
"""
batch
,
num_heads
=
a
.
shape
seq_len
=
1
grid
=
(
batch
,
seq_len
,
triton
.
cdiv
(
num_heads
,
8
))
g
=
torch
.
empty_like
(
a
,
dtype
=
torch
.
float32
)
g
=
torch
.
empty
(
1
,
batch
,
num_heads
,
dtype
=
torch
.
float32
,
device
=
a
.
device
)
beta_output
=
torch
.
empty
(
1
,
batch
,
num_heads
,
dtype
=
torch
.
float32
,
device
=
b
.
device
)
fused_gdn_gating_kernel
[
grid
](
g
,
A_log
,
a
,
dt_bias
,
seq_len
,
num_heads
,
beta
,
threshold
,
8
,
num_warps
=
1
g
,
beta_output
,
A_log
,
a
,
b
,
dt_bias
,
seq_len
,
num_heads
,
beta
,
threshold
,
8
,
num_warps
=
1
,
)
return
g
return
g
,
beta_output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment