change / sglang — Commit b6dd4bcb (unverified)
Authored Sep 16, 2025 by cao1zhg; committed by GitHub on Sep 16, 2025
Parent: b2435be6

feat: update support for qwen3next model (#10466)
Showing 3 changed files with 11 additions and 7 deletions (+11 −7):

  python/sglang/srt/layers/attention/fla/fused_recurrent.py                 +4 −4
  python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py  +2 −2
  python/sglang/srt/models/qwen3_next.py                                     +5 −1
python/sglang/srt/layers/attention/fla/fused_recurrent.py

@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
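Both hunks above, and the matching hunk in fused_sigmoid_gating_recurrent.py below, make the same change to the in-kernel L2 normalization of the query and key blocks: the 1e-6 epsilon moves from outside tl.sqrt to inside it. A minimal PyTorch sketch (illustrative only, not the Triton kernel; names and shapes are made up) of the two forms and where they actually differ:

import torch

EPS = 1e-6

def l2norm_eps_outside(x: torch.Tensor) -> torch.Tensor:
    # Before this commit: x / (sqrt(sum(x^2)) + eps)
    return x / (torch.sqrt(torch.sum(x * x)) + EPS)

def l2norm_eps_inside(x: torch.Tensor) -> torch.Tensor:
    # After this commit: x / sqrt(sum(x^2) + eps)
    return x / torch.sqrt(torch.sum(x * x) + EPS)

# For a normally scaled head-dim vector the two are numerically indistinguishable.
q = torch.randn(128)
print(torch.allclose(l2norm_eps_outside(q), l2norm_eps_inside(q), atol=1e-6))  # True

# They diverge only when the vector norm is on the order of eps itself.
tiny = torch.tensor([1e-6])
print(l2norm_eps_outside(tiny))  # tensor([0.5000])
print(l2norm_eps_inside(tiny))   # ~tensor([0.0010])

Mathematically, the new form is the familiar x * rsqrt(sum(x * x) + eps) normalization written as a division, so the only behavioral change is in the near-zero-norm regime shown above.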
python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py

@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
         # Apply L2 normalization if enabled
         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
         b_q = b_q * scale
python/sglang/srt/models/qwen3_next.py

@@ -239,6 +239,7 @@ class Qwen3GatedDeltaNet(nn.Module):
         self,
         config: Qwen3NextConfig,
         layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -278,6 +279,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_qkvz,
             bias=False,
+            quant_config=quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -285,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_ba,
             bias=False,
+            quant_config=None,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -336,6 +339,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             self.value_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
             input_is_parallel=True,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
@@ -493,7 +497,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, quant_config, alt_stream)
         # Qwen3Next all layers are sparse and have no nextn now
         self.is_layer_sparse = True
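Taken together, the qwen3_next.py hunks thread an optional quantization config from Qwen3HybridLinearDecoderLayer into Qwen3GatedDeltaNet: the qkvz input projection and the output projection now receive quant_config, while the small b/a gating projection is explicitly built with quant_config=None. Below is a condensed, self-contained sketch of that plumbing pattern only; FakeQuantConfig, RecordingLinear, the *Sketch classes, and the projection sizes are all hypothetical stand-ins, not sglang's real parallel linear layers.

from dataclasses import dataclass
from typing import Optional

import torch.nn as nn


@dataclass
class FakeQuantConfig:
    # stand-in for sglang's QuantizationConfig
    method: str = "fp8"


class RecordingLinear(nn.Linear):
    # nn.Linear that simply remembers which quant_config (if any) it was built with
    def __init__(self, in_f: int, out_f: int, quant_config=None):
        super().__init__(in_f, out_f, bias=False)
        self.quant_config = quant_config


class GatedDeltaNetSketch(nn.Module):
    def __init__(self, hidden: int, quant_config: Optional[FakeQuantConfig] = None):
        super().__init__()
        # qkvz input projection and output projection: quantized when a config is given
        self.in_proj_qkvz = RecordingLinear(hidden, 4 * hidden, quant_config=quant_config)
        self.out_proj = RecordingLinear(hidden, hidden, quant_config=quant_config)
        # the small b/a gating projection stays unquantized, as in the @@ -285 hunk
        self.in_proj_ba = RecordingLinear(hidden, 2, quant_config=None)


class DecoderLayerSketch(nn.Module):
    def __init__(self, hidden: int, quant_config: Optional[FakeQuantConfig] = None):
        super().__init__()
        # before this commit the config stopped at the decoder layer; now it is forwarded
        self.linear_attn = GatedDeltaNetSketch(hidden, quant_config=quant_config)


layer = DecoderLayerSketch(64, quant_config=FakeQuantConfig())
print(layer.linear_attn.in_proj_qkvz.quant_config)  # FakeQuantConfig(method='fp8')
print(layer.linear_attn.in_proj_ba.quant_config)    # None

The placeholder widths (4 * hidden, 2) are illustrative; in the real module the projection sizes are derived from the Qwen3Next config.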