Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cafebef1
Unverified
Commit
cafebef1
authored
Oct 30, 2025
by
Even Zhou
Committed by
GitHub
Oct 30, 2025
Browse files
[NPU] bugfix for Qwen3-Next and performance update (#11969)
parent
73dfd2df
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
21 deletions
+68
-21
.github/workflows/release-docker-npu-nightly.yml
.github/workflows/release-docker-npu-nightly.yml
+1
-1
.github/workflows/release-docker-npu.yml
.github/workflows/release-docker-npu.yml
+1
-1
python/sglang/srt/layers/attention/fla/layernorm_gated.py
python/sglang/srt/layers/attention/fla/layernorm_gated.py
+7
-1
python/sglang/srt/layers/attention/mamba/mamba.py
python/sglang/srt/layers/attention/mamba/mamba.py
+20
-11
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+31
-6
python/sglang/srt/models/qwen3_next.py
python/sglang/srt/models/qwen3_next.py
+7
-0
scripts/ci/npu_ci_install_dependency.sh
scripts/ci/npu_ci_install_dependency.sh
+1
-1
No files found.
.github/workflows/release-docker-npu-nightly.yml
View file @
cafebef1
...
...
@@ -73,6 +73,6 @@ jobs:
push
:
${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance
:
false
build-args
:
|
SGLANG_KERNEL_NPU_TAG=2025
0926
SGLANG_KERNEL_NPU_TAG=2025
1030
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
.github/workflows/release-docker-npu.yml
View file @
cafebef1
...
...
@@ -69,6 +69,6 @@ jobs:
push
:
${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance
:
false
build-args
:
|
SGLANG_KERNEL_NPU_TAG=2025
0926
SGLANG_KERNEL_NPU_TAG=2025
1030
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
python/sglang/srt/layers/attention/fla/layernorm_gated.py
View file @
cafebef1
...
...
@@ -12,7 +12,9 @@ import triton
import
triton.language
as
tl
from
einops
import
rearrange
from
sglang.srt.utils
import
device_context
from
sglang.srt.utils
import
device_context
,
is_npu
_is_npu
=
is_npu
()
def
rms_norm_ref
(
...
...
@@ -182,6 +184,10 @@ def _layer_norm_fwd(
return
out
,
mean
,
rstd
if
_is_npu
:
from
sgl_kernel_npu.fla.layernorm_gated
import
layer_norm_fwd_npu
as
_layer_norm_fwd
def
rms_norm_gated
(
*
,
x
,
...
...
python/sglang/srt/layers/attention/mamba/mamba.py
View file @
cafebef1
...
...
@@ -13,16 +13,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size
,
)
from
sglang.srt.distributed.utils
import
divide
from
sglang.srt.layers.attention.mamba.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
,
)
from
sglang.srt.layers.attention.mamba.causal_conv1d_triton
import
(
causal_conv1d_fn
as
causal_conv1d_fn_triton
,
)
from
sglang.srt.layers.attention.mamba.causal_conv1d_triton
import
(
causal_conv1d_update
as
causal_conv1d_update_triton
,
)
from
sglang.srt.layers.attention.mamba.mamba2_metadata
import
Mamba2Metadata
from
sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated
import
Mixer2RMSNormGated
from
sglang.srt.layers.attention.mamba.ops
import
(
...
...
@@ -40,7 +30,26 @@ from sglang.srt.model_loader.weight_utils import (
composed_weight_loader
,
sharded_weight_loader
,
)
from
sglang.srt.utils
import
set_weight_attrs
from
sglang.srt.utils
import
is_cuda
,
is_npu
,
set_weight_attrs
if
is_cuda
():
from
sglang.srt.layers.attention.mamba.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
,
)
from
sglang.srt.layers.attention.mamba.causal_conv1d_triton
import
(
causal_conv1d_fn
as
causal_conv1d_fn_triton
,
)
from
sglang.srt.layers.attention.mamba.causal_conv1d_triton
import
(
causal_conv1d_update
as
causal_conv1d_update_triton
,
)
elif
is_npu
():
from
sgl_kernel_npu.mamba.causal_conv1d
import
(
causal_conv1d_fn_npu
as
causal_conv1d_fn
,
)
from
sgl_kernel_npu.mamba.causal_conv1d
import
(
causal_conv1d_update_npu
as
causal_conv1d_update
,
)
LoaderFunction
=
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
None
]
...
...
python/sglang/srt/layers/moe/topk.py
View file @
cafebef1
...
...
@@ -314,16 +314,41 @@ class TopK(CustomOp):
num_token_non_padded
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_location_dispatch_info
:
Optional
[
ExpertLocationDispatchInfo
]
=
None
,
)
->
TopKOutput
:
global_num_experts
=
router_logits
.
shape
[
-
1
]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if
global_num_experts
==
256
:
use_grouped_topk
=
self
.
topk_config
.
use_grouped_topk
torch_native
=
self
.
topk_config
.
torch_native
renormalize
=
self
.
topk_config
.
renormalize
if
not
use_grouped_topk
and
not
torch_native
:
topk_weights
,
topk_ids
,
_
=
torch_npu
.
npu_moe_gating_top_k_softmax
(
router_logits
,
k
=
self
.
topk_config
.
top_k
,
)
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
if
renormalize
:
topk_weights_sum
=
(
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
if
self
.
topk_config
.
num_fused_shared_experts
==
0
else
topk_weights
[:,
:
-
1
].
sum
(
dim
=-
1
,
keepdim
=
True
)
)
topk_weights
=
topk_weights
/
topk_weights_sum
if
expert_location_dispatch_info
is
not
None
:
topk_ids
=
topk_ids_logical_to_physical
(
topk_ids
,
expert_location_dispatch_info
)
get_global_expert_distribution_recorder
().
on_select_experts
(
topk_ids
=
topk_ids
)
return
StandardTopKOutput
(
topk_weights
,
topk_ids
,
_
)
if
use_grouped_topk
and
not
torch_native
and
router_logits
.
shape
[
-
1
]
==
256
:
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
routed_scaling_factor
=
self
.
topk_config
.
routed_scaling_factor
or
1
router_logits
=
router_logits
.
to
(
torch
.
float32
)
topk_weights
,
topk_ids
,
_
=
torch_npu
.
npu_moe_gating_top_k
(
router_logits
,
router_logits
.
to
(
torch
.
float32
)
,
k
=
self
.
topk_config
.
top_k
,
bias
=
self
.
topk_config
.
correction_bias
.
to
(
torch
.
float32
),
k_group
=
self
.
topk_config
.
topk_group
,
...
...
@@ -335,7 +360,7 @@ class TopK(CustomOp):
eps
=
float
(
1e-20
),
)
if
self
.
topk_config
.
renormalize
:
if
renormalize
:
topk_weights_sum
=
(
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
if
self
.
topk_config
.
num_fused_shared_experts
==
0
...
...
python/sglang/srt/models/qwen3_next.py
View file @
cafebef1
...
...
@@ -478,6 +478,13 @@ class Qwen3GatedDeltaNet(nn.Module):
# reshape input data into 2D tensor
core_attn_out
=
core_attn_out
.
reshape
(
-
1
,
core_attn_out
.
shape
[
-
1
])
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
# Add padding for DP-Attn
if
is_dp_attention_enabled
():
core_attn_out_pad
=
torch
.
zeros_like
(
z
)
core_attn_out_pad
[:
core_attn_out
.
shape
[
0
],
:]
=
core_attn_out
core_attn_out
=
core_attn_out_pad
core_attn_out
=
self
.
norm
(
core_attn_out
,
z
)
core_attn_out
=
core_attn_out
.
reshape
(
z_shape_og
)
core_attn_out
=
core_attn_out
.
reshape
(
*
core_attn_out
.
shape
[:
-
2
],
-
1
)
...
...
scripts/ci/npu_ci_install_dependency.sh
View file @
cafebef1
...
...
@@ -59,7 +59,7 @@ wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./
### Install sgl-kernel-npu
SGL_KERNEL_NPU_TAG
=
"2025
0926
"
SGL_KERNEL_NPU_TAG
=
"2025
1030
"
git clone
--depth
1 https://github.com/sgl-project/sgl-kernel-npu.git
--branch
${
SGL_KERNEL_NPU_TAG
}
# pin wheel to 0.45.1 ref: https://github.com/pypa/wheel/issues/662
pip
install
wheel
==
0.45.1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment