Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc6b5784
Unverified
Commit
dc6b5784
authored
Mar 08, 2026
by
Xin Yang
Committed by
GitHub
Mar 08, 2026
Browse files
[Kernel] Add fused_sigmoid_gating_delta_rule_update kernel for Qwen3 Next (#35777)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
1bc9c77f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
509 additions
and
31 deletions
+509
-31
tests/kernels/test_fused_sigmoid_gating_delta_rule.py
tests/kernels/test_fused_sigmoid_gating_delta_rule.py
+196
-0
vllm/model_executor/layers/fla/ops/__init__.py
vllm/model_executor/layers/fla/ops/__init__.py
+2
-0
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
+279
-0
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+32
-31
No files found.
tests/kernels/test_fused_sigmoid_gating_delta_rule.py
0 → 100644
View file @
dc6b5784
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.fla.ops
import
(
fused_recurrent_gated_delta_rule
,
fused_sigmoid_gating_delta_rule_update
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
DEVICE
=
current_platform
.
device_type
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("num_reqs", [1, 2, 4])
@pytest.mark.parametrize("num_k_heads", [16])
@pytest.mark.parametrize("num_v_heads", [32])
@pytest.mark.parametrize("head_k_dim", [128])
@pytest.mark.parametrize("head_v_dim", [128])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_fused_sigmoid_gating_delta_rule_update_non_spec(
    tp_size: int,
    num_reqs: int,
    num_k_heads: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    dtype: torch.dtype,
) -> None:
    """Compare the fused gating kernel with the unfused path for plain decode.

    The reference computes ``beta = sigmoid(b)`` and
    ``g = -exp(A_log) * softplus(a + dt_bias)`` in PyTorch and feeds them to
    ``fused_recurrent_gated_delta_rule``; the kernel under test receives the
    raw ``A_log`` / ``a`` / ``b`` / ``dt_bias`` instead. Both the attention
    output and the returned state must agree within 1e-2 (abs and rel).
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    tokens_per_req = 1  # decode feeds exactly one token per request
    num_tokens = num_reqs * tokens_per_req
    # The state pool is twice as large as the number of slots actually used.
    num_state_slots = 2 * num_tokens

    qk_width = num_k_heads * head_k_dim
    v_width = num_v_heads * head_v_dim
    packed_width = (2 * qk_width + v_width) // tp_size

    # Random tensors are drawn in a fixed order so the seed yields the same data.
    packed_qkv = torch.rand(num_tokens, packed_width, dtype=dtype)
    query, key, value = packed_qkv.split(
        [qk_width // tp_size, qk_width // tp_size, v_width // tp_size],
        dim=-1,
    )
    qk_shape = (1, num_tokens, num_k_heads, head_k_dim)
    query = query.view(qk_shape)
    key = key.view(qk_shape)
    value = value.view(1, num_tokens, num_v_heads, head_v_dim)

    local_v_heads = num_v_heads // tp_size
    A_log = torch.rand(local_v_heads, dtype=dtype)
    dt_bias = torch.rand(local_v_heads, dtype=dtype)
    a = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    b = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    state_pool = torch.rand(
        num_state_slots, num_v_heads, head_k_dim, head_v_dim, dtype=dtype
    )
    # Each token gets a distinct, randomly chosen slot of the pool.
    slot_ids = torch.randperm(num_state_slots, dtype=torch.int32)[:num_tokens]
    seq_starts = torch.arange(0, num_tokens + 1, dtype=torch.int32)

    # Arguments that are identical for the reference and the fused call.
    shared_kwargs = dict(
        q=query,
        k=key,
        v=value,
        inplace_final_state=True,
        ssm_state_indices=slot_ids,
        cu_seqlens=seq_starts,
        use_qk_l2norm_in_kernel=True,
    )

    # Unfused reference: gating is evaluated in PyTorch first.
    write_gate = torch.sigmoid(b)
    log_decay = -torch.exp(A_log.float()) * F.softplus(a.float() + dt_bias)
    expected_out, expected_state = fused_recurrent_gated_delta_rule(
        g=log_decay.unsqueeze(0),
        beta=write_gate.unsqueeze(0),
        # The reference works on a copy so the fused call below starts from
        # the same untouched state.
        initial_state=state_pool.clone(),
        **shared_kwargs,
    )

    # Kernel under test: gating is computed inside the kernel.
    actual_out, actual_state = fused_sigmoid_gating_delta_rule_update(
        A_log=A_log,
        a=a,
        b=b,
        dt_bias=dt_bias,
        initial_state=state_pool,
        **shared_kwargs,
    )

    tolerances = dict(atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(actual_out, expected_out, **tolerances)
    torch.testing.assert_close(actual_state, expected_state, **tolerances)
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("num_reqs", [1, 2, 4])
@pytest.mark.parametrize("num_k_heads", [16])
@pytest.mark.parametrize("num_v_heads", [32])
@pytest.mark.parametrize("head_k_dim", [128])
@pytest.mark.parametrize("head_v_dim", [128])
@pytest.mark.parametrize("num_speculative_tokens", [1, 3])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_fused_sigmoid_gating_delta_rule_update_spec(
    tp_size: int,
    num_reqs: int,
    num_k_heads: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    num_speculative_tokens: int,
    dtype: torch.dtype,
) -> None:
    """Compare the fused gating kernel with the unfused path for spec decode.

    Every request contributes ``num_speculative_tokens + 1`` tokens and owns
    one state slot per token. The reference computes ``beta = sigmoid(b)`` and
    ``g = -exp(A_log) * softplus(a + dt_bias)`` in PyTorch before calling
    ``fused_recurrent_gated_delta_rule``; the kernel under test receives the
    raw gating inputs. Outputs and states must agree within 1e-2.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    tokens_per_req = num_speculative_tokens + 1
    num_tokens = num_reqs * tokens_per_req
    # The state pool is twice as large as the number of slots actually used.
    num_state_slots = 2 * num_tokens

    qk_width = num_k_heads * head_k_dim
    v_width = num_v_heads * head_v_dim
    packed_width = (2 * qk_width + v_width) // tp_size

    # Random tensors are drawn in a fixed order so the seed yields the same data.
    packed_qkv = torch.rand(num_tokens, packed_width, dtype=dtype)
    query, key, value = packed_qkv.split(
        [qk_width // tp_size, qk_width // tp_size, v_width // tp_size],
        dim=-1,
    )
    qk_shape = (1, num_tokens, num_k_heads, head_k_dim)
    query = query.view(qk_shape)
    key = key.view(qk_shape)
    value = value.view(1, num_tokens, num_v_heads, head_v_dim)

    local_v_heads = num_v_heads // tp_size
    A_log = torch.rand(local_v_heads, dtype=dtype)
    dt_bias = torch.rand(local_v_heads, dtype=dtype)
    a = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    b = torch.rand(num_tokens, num_v_heads, dtype=dtype)
    state_pool = torch.rand(
        num_state_slots, num_v_heads, head_k_dim, head_v_dim, dtype=dtype
    )
    # One distinct slot per (request, token position).
    shuffled_slots = torch.randperm(num_state_slots, dtype=torch.int32)
    slot_ids = shuffled_slots[:num_tokens].view(num_reqs, tokens_per_req)
    # randint's upper bound is exclusive, so each request reports between
    # 1 and num_speculative_tokens accepted tokens.
    num_accepted_tokens = torch.randint(
        1, tokens_per_req, (num_reqs,), dtype=torch.int32
    )
    # Request boundaries: every request spans tokens_per_req tokens.
    seq_starts = torch.arange(
        0, num_tokens + 1, tokens_per_req, dtype=torch.int32
    )

    # Arguments that are identical for the reference and the fused call.
    shared_kwargs = dict(
        q=query,
        k=key,
        v=value,
        inplace_final_state=True,
        ssm_state_indices=slot_ids,
        cu_seqlens=seq_starts,
        num_accepted_tokens=num_accepted_tokens,
        use_qk_l2norm_in_kernel=True,
    )

    # Unfused reference: gating is evaluated in PyTorch first.
    write_gate = torch.sigmoid(b)
    log_decay = -torch.exp(A_log.float()) * F.softplus(a.float() + dt_bias)
    expected_out, expected_state = fused_recurrent_gated_delta_rule(
        g=log_decay.unsqueeze(0),
        beta=write_gate.unsqueeze(0),
        # The reference works on a copy so the fused call below starts from
        # the same untouched state.
        initial_state=state_pool.clone(),
        **shared_kwargs,
    )

    # Kernel under test: gating is computed inside the kernel.
    actual_out, actual_state = fused_sigmoid_gating_delta_rule_update(
        A_log=A_log,
        a=a,
        b=b,
        dt_bias=dt_bias,
        initial_state=state_pool,
        **shared_kwargs,
    )

    tolerances = dict(atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(actual_out, expected_out, **tolerances)
    torch.testing.assert_close(actual_state, expected_state, **tolerances)
vllm/model_executor/layers/fla/ops/__init__.py
View file @
dc6b5784
...
@@ -8,10 +8,12 @@
...
@@ -8,10 +8,12 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from
.chunk
import
chunk_gated_delta_rule
from
.chunk
import
chunk_gated_delta_rule
from
.fused_recurrent
import
fused_recurrent_gated_delta_rule
from
.fused_recurrent
import
fused_recurrent_gated_delta_rule
from
.fused_sigmoid_gating
import
fused_sigmoid_gating_delta_rule_update
from
.layernorm_guard
import
RMSNormGated
from
.layernorm_guard
import
RMSNormGated
__all__
=
[
__all__
=
[
"RMSNormGated"
,
"RMSNormGated"
,
"chunk_gated_delta_rule"
,
"chunk_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule"
,
"fused_sigmoid_gating_delta_rule_update"
,
]
]
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
0 → 100644
View file @
dc6b5784
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import
torch
from
vllm.triton_utils
import
tl
,
triton
# Fused kernel: evaluates the gating terms
#   g    = -exp(A_log) * softplus(a + dt_bias)   (softplus with `beta`/`threshold`)
#   beta = sigmoid(b)
# and applies the recurrent gated delta-rule update token by token.
#
# Launch grid is (NK, NV, N * HV): program axis 0 tiles K, axis 1 tiles V and
# axis 2 enumerates (sequence, value-head) pairs.
#
# Memory layout implied by the pointer arithmetic below:
#   q, k      : [tokens, H, K]
#   v, o      : [tokens, HV, V] (o carries an extra leading NK axis)
#   b         : [tokens, HV]
#   a         : [tokens, HV]     (IS_KDA: [tokens, HV, K])
#   dt_bias   : [HV]             (IS_KDA: [HV, K])
#   A_log     : [HV]
#   h0 / ht   : [slot, HV, V, K] with `stride_*_state_token` elements per slot
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    b,
    dt_bias,
    beta,
    threshold,
    q,
    k,
    v,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Split the fused (sequence, value-head) program index.
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Several value heads share one q/k head.
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        # Total token count across all sequences; T becomes this sequence's
        # own length from here on.
        all_tokens = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_A_log = A_log + i_hv
    if not IS_KDA:
        # One gate value per (token, value head).
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    else:
        # One gate value per (token, value head, key channel).
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all_tokens + bos) * HV + i_hv) * V + o_v

    # BK / BV may exceed K / V; mask off the padded lanes.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Start from the state stored for the last accepted token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # Load state index and check for PAD_SLOT_ID (-1).
            # The token offset honours the index tensor's token stride so
            # that non-contiguous index views are read correctly.
            state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            # Skip if state index is invalid (PAD_SLOT_ID = -1)
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        # If the model is loaded in fp16, without the .float() here, A might be -inf
        if IS_KDA:
            # Vector loads over the key dimension: mask the padded lanes so
            # nothing is read past K when K is not a power of two. Padded
            # lanes only ever scale state columns that are zero.
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)
        else:
            b_a = tl.load(p_a).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
        x = b_a + b_dt_bias
        # softplus(x) = log(1 + exp(beta * x)) / beta, switching to the
        # identity above `threshold`.
        softplus_x = tl.where(
            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
        )
        b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x

        # compute beta_output = sigmoid(b)
        b_beta = tl.sigmoid(b_b.to(tl.float32))

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale

        # Decay the running state. [BV, BK]
        if not IS_KDA:
            b_h *= tl.exp(b_g)
        else:
            b_h *= tl.exp(b_g[None, :])
        # Delta rule: subtract what the state already predicts for k. [BV]
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        b_v *= b_beta
        # Rank-1 state update. [BV, BK]
        b_h += b_v[:, None] * b_k[None, :]
        # Read-out with the scaled query. [BV]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Load state index and check for PAD_SLOT_ID (-1)
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            # Only store if state index is valid (not PAD_SLOT_ID)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Update pointers for next timestep
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_b += HV
        if IS_KDA:
            # `a` holds K values per (token, head) in KDA mode, so one token
            # is HV * K elements wide.
            p_a += HV * K
        else:
            p_a += HV
def
fused_sigmoid_gating_delta_rule_update
(
A_log
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
dt_bias
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
beta
:
float
=
1.0
,
threshold
:
float
=
20.0
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
ssm_state_indices
:
torch
.
Tensor
|
None
=
None
,
num_accepted_tokens
:
torch
.
Tensor
|
None
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
is_kda
:
bool
=
False
,
):
"""
Fused triton implementation of sigmoid gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating
computation and the recurrent delta rule update for better performance.
"""
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
HV
=
v
.
shape
[
2
]
N
=
B
if
cu_seqlens
is
None
else
len
(
cu_seqlens
)
-
1
BK
,
BV
=
triton
.
next_power_of_2
(
K
),
min
(
triton
.
next_power_of_2
(
V
),
32
)
NK
,
NV
=
triton
.
cdiv
(
K
,
BK
),
triton
.
cdiv
(
V
,
BV
)
assert
NK
==
1
,
"NK > 1 is not supported yet"
num_stages
=
3
num_warps
=
4
if
cu_seqlens
is
not
None
and
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
"
f
" when using `cu_seqlens`. Please flatten variable-length"
f
" inputs before processing."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**
-
0.5
else
:
assert
scale
>
0
,
"scale must be positive"
o
=
q
.
new_empty
(
NK
,
*
v
.
shape
)
if
inplace_final_state
:
final_state
=
initial_state
else
:
final_state
=
q
.
new_empty
(
T
,
HV
,
V
,
K
,
dtype
=
initial_state
.
dtype
)
stride_init_state_token
=
initial_state
.
stride
(
0
)
stride_final_state_token
=
final_state
.
stride
(
0
)
if
ssm_state_indices
is
None
:
stride_indices_seq
,
stride_indices_tok
=
1
,
1
elif
ssm_state_indices
.
ndim
==
1
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
(
0
),
1
else
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
()
grid
=
(
NK
,
NV
,
N
*
HV
)
fused_sigmoid_gating_delta_rule_update_kernel
[
grid
](
A_log
=
A_log
,
a
=
a
.
contiguous
(),
b
=
b
.
contiguous
(),
dt_bias
=
dt_bias
,
beta
=
beta
,
threshold
=
threshold
,
q
=
q
.
contiguous
(),
k
=
k
.
contiguous
(),
v
=
v
.
contiguous
(),
o
=
o
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
scale
=
scale
,
N
=
N
,
T
=
T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
BK
=
BK
,
BV
=
BV
,
stride_init_state_token
=
stride_init_state_token
,
stride_final_state_token
=
stride_final_state_token
,
stride_indices_seq
=
stride_indices_seq
,
stride_indices_tok
=
stride_indices_tok
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
IS_KDA
=
is_kda
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
o
=
o
.
squeeze
(
0
)
return
o
,
final_state
vllm/model_executor/models/qwen3_next.py
View file @
dc6b5784
...
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
...
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule
as
fla_chunk_gated_delta_rule
,
chunk_gated_delta_rule
as
fla_chunk_gated_delta_rule
,
)
)
from
vllm.model_executor.layers.fla.ops
import
(
from
vllm.model_executor.layers.fla.ops
import
(
fused_
recurrent
_gat
ed
_delta_rule
,
fused_
sigmoid
_gat
ing
_delta_rule
_update
,
)
)
from
vllm.model_executor.layers.fla.ops.chunk
import
l2norm_fwd
from
vllm.model_executor.layers.fla.ops.chunk
import
l2norm_fwd
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
...
@@ -731,42 +731,41 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -731,42 +731,41 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec
mixed_qkv_non_spec
)
)
if
attn_metadata
.
num_prefills
>
0
:
g
,
beta
=
fused_gdn_gating
(
self
.
A_log
,
a
,
b
,
self
.
dt_bias
)
g
,
beta
=
fused_gdn_gating
(
self
.
A_log
,
a
,
b
,
self
.
dt_bias
)
if
spec_sequence_masks
is
not
None
:
if
spec_sequence_masks
is
not
None
:
if
attn_metadata
.
num_prefills
==
0
and
attn_metadata
.
num_decodes
==
0
:
g_spec
=
g
beta_spec
=
beta
g_non_spec
=
None
beta_non_spec
=
None
else
:
g_spec
=
g
.
index_select
(
1
,
spec_token_indx
)
beta_spec
=
beta
.
index_select
(
1
,
spec_token_indx
)
g_non_spec
=
g
.
index_select
(
1
,
non_spec_token_indx
)
g_non_spec
=
g
.
index_select
(
1
,
non_spec_token_indx
)
beta_non_spec
=
beta
.
index_select
(
1
,
non_spec_token_indx
)
beta_non_spec
=
beta
.
index_select
(
1
,
non_spec_token_indx
)
else
:
else
:
g_spec
=
None
beta_spec
=
None
g_non_spec
=
g
g_non_spec
=
g
beta_non_spec
=
beta
beta_non_spec
=
beta
else
:
g_non_spec
=
None
beta_non_spec
=
None
# 2. Recurrent attention
# 2. Recurrent attention
# 2.1: Process the multi-query part
# 2.1: Process the multi-query part
if
spec_sequence_masks
is
not
None
:
if
spec_sequence_masks
is
not
None
:
core_attn_out_spec
,
last_recurrent_state
=
fused_recurrent_gated_delta_rule
(
core_attn_out_spec
,
last_recurrent_state
=
(
fused_sigmoid_gating_delta_rule_update
(
A_log
=
self
.
A_log
,
a
=
a
,
b
=
b
,
dt_bias
=
self
.
dt_bias
,
q
=
query_spec
,
q
=
query_spec
,
k
=
key_spec
,
k
=
key_spec
,
v
=
value_spec
,
v
=
value_spec
,
g
=
g_spec
,
beta
=
beta_spec
,
initial_state
=
ssm_state
,
initial_state
=
ssm_state
,
inplace_final_state
=
True
,
inplace_final_state
=
True
,
cu_seqlens
=
spec_query_start_loc
[:
attn_metadata
.
num_spec_decodes
+
1
],
cu_seqlens
=
spec_query_start_loc
[
:
attn_metadata
.
num_spec_decodes
+
1
],
ssm_state_indices
=
spec_state_indices_tensor
,
ssm_state_indices
=
spec_state_indices_tensor
,
num_accepted_tokens
=
num_accepted_tokens
,
num_accepted_tokens
=
num_accepted_tokens
,
use_qk_l2norm_in_kernel
=
True
,
use_qk_l2norm_in_kernel
=
True
,
)
)
)
else
:
else
:
core_attn_out_spec
,
last_recurrent_state
=
None
,
None
core_attn_out_spec
,
last_recurrent_state
=
None
,
None
...
@@ -794,12 +793,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -794,12 +793,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
)
elif
attn_metadata
.
num_decodes
>
0
:
elif
attn_metadata
.
num_decodes
>
0
:
core_attn_out_non_spec
,
last_recurrent_state
=
(
core_attn_out_non_spec
,
last_recurrent_state
=
(
fused_recurrent_gated_delta_rule
(
fused_sigmoid_gating_delta_rule_update
(
A_log
=
self
.
A_log
,
a
=
a
,
b
=
b
,
dt_bias
=
self
.
dt_bias
,
q
=
query_non_spec
,
q
=
query_non_spec
,
k
=
key_non_spec
,
k
=
key_non_spec
,
v
=
value_non_spec
,
v
=
value_non_spec
,
g
=
g_non_spec
,
beta
=
beta_non_spec
,
initial_state
=
ssm_state
,
initial_state
=
ssm_state
,
inplace_final_state
=
True
,
inplace_final_state
=
True
,
cu_seqlens
=
non_spec_query_start_loc
[
cu_seqlens
=
non_spec_query_start_loc
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment