sglang · commit 5b64f006 (unverified)
Authored Sep 11, 2025 by Even Zhou; committed by GitHub Sep 10, 2025
[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)
Parent: 5b7448de

Showing 15 changed files with 317 additions and 109 deletions.
.github/workflows/pr-test-npu.yml                               +36   -0
.github/workflows/release-docker-npu-nightly.yml                 +1   -0
.github/workflows/release-docker-npu.yml                         +1   -3
python/sglang/srt/eplb/eplb_manager.py                           +2   -2
python/sglang/srt/eplb/expert_distribution.py                   +12   -4
python/sglang/srt/eplb/expert_location_updater.py                +1   -1
python/sglang/srt/layers/attention/ascend_backend.py            +10   -3
python/sglang/srt/layers/moe/ep_moe/layer.py                   +108  -48
python/sglang/srt/layers/moe/token_dispatcher/__init__.py        +0   -2
python/sglang/srt/layers/moe/token_dispatcher/base.py            +0  -11
python/sglang/srt/layers/moe/token_dispatcher/deepep.py          +8  -35
python/sglang/srt/layers/moe/topk.py                             +8   -0
scripts/ci/npu_ci_install_dependency.sh                          +6   -0
test/srt/ascend/test_ascend_deepep.py                          +121   -0
test/srt/run_suite.py                                            +3   -0
.github/workflows/pr-test-npu.yml

@@ -127,12 +127,48 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite per-commit-4-ascend-npu --timeout-per-file 3600

+  per-commit-16-ascend-a3:
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
+    runs-on: linux-aarch64-a3-16
+    container:
+      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          # speed up by using infra cache services
+          CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local"
+          sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list
+          pip config set global.index-url http://${CACHING_URL}/pypi/simple
+          pip config set global.trusted-host ${CACHING_URL}
+          bash scripts/ci/npu_ci_install_dependency.sh
+          # copy required file from our daily cache
+          cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp
+          # download through proxy
+          curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
+
+      - name: Run test
+        timeout-minutes: 90
+        env:
+          SGLANG_USE_MODELSCOPE: true
+          SGLANG_IS_IN_CI: true
+          HF_ENDPOINT: https://hf-mirror.com
+          TORCH_EXTENSIONS_DIR: /tmp/torch_extensions
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 5400
+
   pr-test-npu-finish:
     if: always()
     needs:
       - per-commit-1-ascend-npu
       - per-commit-2-ascend-npu
       - per-commit-4-ascend-npu
+      - per-commit-16-ascend-a3
     runs-on: ubuntu-latest
     steps:
       - name: Check all dependent job statuses
.github/workflows/release-docker-npu-nightly.yml

@@ -72,5 +72,6 @@ jobs:
           push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
           provenance: false
           build-args: |
+            SGLANG_KERNEL_NPU_TAG=20250901
             CANN_VERSION=${{ matrix.cann_version }}
             DEVICE_TYPE=${{ matrix.device_type }}
.github/workflows/release-docker-npu.yml

@@ -54,8 +54,6 @@ jobs:
         run: |
           version=$(cat python/sglang/version.py | cut -d'"' -f2)
           echo "TAG=lmsysorg/sglang:v$version-cann${{ matrix.cann_version }}-${{ matrix.device_type }}" >> $GITHUB_OUTPUT
-          kernel_tag=$(curl -s https://api.github.com/repos/sgl-project/sgl-kernel-npu/tags | jq -r '.[0].name')
-          echo "KERNEL_NPU_TAG=${kernel_tag}" >> $GITHUB_OUTPUT
       - name: Build and push Docker image
         id: build-and-push

@@ -70,6 +68,6 @@ jobs:
           push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
           provenance: false
           build-args: |
-            SGLANG_KERNEL_NPU_TAG=${{ steps.get_version.outputs.KERNEL_NPU_TAG }}
+            SGLANG_KERNEL_NPU_TAG=20250901
             CANN_VERSION=${{ matrix.cann_version }}
             DEVICE_TYPE=${{ matrix.device_type }}
python/sglang/srt/eplb/eplb_manager.py

@@ -55,7 +55,7 @@ class EPLBManager:
         enable_timing = self._rebalance_layers_per_chunk is None
         if enable_timing:
-            torch.cuda.synchronize()
+            torch.get_device_module().synchronize()
             time_start = time.time()
         dump_record_output = get_global_expert_distribution_recorder().dump_record(

@@ -85,7 +85,7 @@ class EPLBManager:
         msg = f"[EPLBManager] rebalance end"
         if enable_timing:
-            torch.cuda.synchronize()
+            torch.get_device_module().synchronize()
             time_end = time.time()
             msg += f" time={time_end - time_start:.3f}s"
         logger.info(msg)
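This two-line change is the template for most of the commit: every hard-coded torch.cuda call becomes torch.get_device_module(), which in recent PyTorch resolves to torch.cuda on CUDA builds and to torch.npu once torch_npu is loaded, so one call site serves both backends. A minimal sketch of the pattern (the helper name is ours, for illustration only):

import time

import torch

def synced_time() -> float:
    # Device-agnostic stand-in for torch.cuda.synchronize() + time.time():
    # torch.get_device_module() returns the module of the active accelerator.
    torch.get_device_module().synchronize()
    return time.time()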
python/sglang/srt/eplb/expert_distribution.py

@@ -30,7 +30,9 @@ import torch.distributed
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import Withable, get_bool_env_var
+from sglang.srt.utils import Withable, get_bool_env_var, is_npu
+
+_is_npu = is_npu()

 if TYPE_CHECKING:
     from sglang.srt.eplb.expert_location import ExpertLocationMetadata

@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
     def _on_hook(self, hook_name: str, **kwargs):
         if self._disable_all:
             return
-        if not (self._recording or torch.cuda.is_current_stream_capturing()):
+        if not (
+            self._recording or torch.get_device_module().is_current_stream_capturing()
+        ):
             return
         gatherer = self._single_pass_gatherers[
             self._accumulator.get_single_pass_gatherer_key(

@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
 class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
         super().__init__(*args, **kwargs)
+        if not _is_npu:
+            device = "cuda"
+        else:
+            device = "npu"
         self._enable_global_physical_experts = enable_global_physical_experts
         self._data = torch.zeros(
             (

@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
             ),
             dtype=torch.int,
-            device="cuda",
+            device=device,
         )

     def reset(self):

@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
         if self._first_dump:
             self._first_dump = False
-            torch.cuda.empty_cache()
+            torch.get_device_module().empty_cache()
             torch.distributed.all_reduce(
                 logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
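The gatherer buffer follows the same theme: the allocation device is chosen once from is_npu() instead of being hard-coded to "cuda". A compact sketch of the added branch; the torch_npu probe below is an illustrative stand-in, not sglang's actual is_npu() implementation:

import importlib.util

import torch

def is_npu() -> bool:
    # Stand-in check: sglang's real is_npu() detects the Ascend backend.
    return importlib.util.find_spec("torch_npu") is not None

device = "npu" if is_npu() else "cuda"  # equivalent to the added if/else
data = torch.zeros((8, 16), dtype=torch.int, device=device)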
python/sglang/srt/eplb/expert_location_updater.py

@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
     ):
         if self._first_execution:
             self._first_execution = False
-            torch.cuda.empty_cache()
+            torch.get_device_module().empty_cache()
         old_expert_location_metadata = get_global_expert_location_metadata()
         assert old_expert_location_metadata is not None
python/sglang/srt/layers/attention/ascend_backend.py

@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_bool_env_var

@@ -33,6 +34,7 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None


 class AscendAttnBackend(AttentionBackend):

@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()

         self.forward_metadata.block_tables = (

@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-            self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
-                forward_batch.extend_seq_lens_cpu
-            )
+            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+            if forward_batch.is_extend_in_batch:
+                seq_lens_list_cumsum[-1] = (
+                    (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
+                ) * tp_size
+            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

         self.graph_mode = False
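The new is_extend_in_batch branch rounds the last cumulative extend length up to a multiple of the attention TP size, presumably so the padded total divides evenly across TP ranks; ((x - 1) // tp + 1) * tp is the standard round-up-to-multiple idiom. A quick numeric check of the arithmetic (values made up):

import numpy as np

def round_up(x: int, multiple: int) -> int:
    # The diff's idiom: smallest multiple of `multiple` that is >= x.
    return ((x - 1) // multiple + 1) * multiple

tp_size = 4
seq_lens_list_cumsum = np.cumsum([3, 4, 3])   # -> array([ 3,  7, 10])
seq_lens_list_cumsum[-1] = round_up(int(seq_lens_list_cumsum[-1]), tp_size)
print(seq_lens_list_cumsum)                    # [ 3  7 12]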
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,

@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8

@@ -718,63 +717,124 @@ class DeepEPMoE(EPMoE):
     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1

-        pertoken_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output

-        group_list_type = 1
-        seg_indptr = seg_indptr.to(torch.int64)
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )

-        import torch_npu
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )

-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=torch.int32,
-        )[0]
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]

-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=seg_indptr,
-            activate_left=True,
-            quant_mode=1,
-        )
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)

-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]

-        return hidden_states
+            return hidden_states
+
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


 def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
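Both nested helpers run the same W8A8 pipeline on torch_npu grouped kernels: an int8 grouped matmul for gate_up_proj, a fused dequant-SwiGLU-requant, and a second grouped matmul for down_proj, with group_list giving per-expert token counts (group_list_type=1). For orientation, here is an unquantized plain-PyTorch sketch of that data flow; it mirrors the structure only, not the torch_npu kernels or their int8 quantization:

import torch

def grouped_moe_reference(
    x: torch.Tensor,            # [num_tokens, hidden], tokens already sorted by expert
    w13: torch.Tensor,          # [num_experts, hidden, 2 * inter]
    w2: torch.Tensor,           # [num_experts, inter, hidden]
    tokens_per_expert: list,    # group_list_type=1 style: counts per expert
) -> torch.Tensor:
    outs = []
    start = 0
    for e, n in enumerate(tokens_per_expert):
        chunk = x[start : start + n]
        gate_up = chunk @ w13[e]                    # gmm1: gate_up_proj
        gate, up = gate_up.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up   # swiglu ("activate_left")
        outs.append(act @ w2[e])                    # gmm2: down_proj
        start += n
    return torch.cat(outs)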
python/sglang/srt/layers/moe/token_dispatcher/__init__.py

@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutputFormat,
 )
 from sglang.srt.layers.moe.token_dispatcher.deepep import (
-    AscendDeepEPLLOutput,
     DeepEPConfig,
     DeepEPDispatcher,
     DeepEPLLCombineInput,

@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
 )

 __all__ = [
-    "AscendDeepEPLLOutput",
     "BaseDispatcher",
     "BaseDispatcherConfig",
     "CombineInput",
python/sglang/srt/layers/moe/token_dispatcher/base.py

@@ -8,7 +8,6 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLCombineInput,
         DeepEPLLOutput,
         DeepEPNormalCombineInput,

@@ -47,19 +46,12 @@ class DispatchOutputChecker:
     ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
         return dispatch_output.format.is_deepep()

-    @staticmethod
-    def format_is_ascent_ll(
-        dispatch_output: DispatchOutput,
-    ) -> TypeGuard[AscendDeepEPLLOutput]:
-        return dispatch_output.format.is_ascent_ll()
-

 class DispatchOutputFormat(Enum):

     STANDARD = "standard"
     DEEPEP_NORMAL = "deepep_normal"
     DEEPEP_LL = "deepep_ll"
-    ASCENT_LL = "ascent_ll"

     def is_standard(self) -> bool:
         return self == DispatchOutputFormat.STANDARD

@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
         DispatchOutputFormat.DEEPEP_LL,
     ]

-    def is_ascent_ll(self) -> bool:
-        return self == DispatchOutputFormat.ASCENT_LL
-

 @runtime_checkable
 class DispatchOutput(Protocol):
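With the ASCENT_LL variant gone, DispatchOutputChecker is back to the three shared formats. Its static methods return TypeGuards, so a passing check narrows the union type for static type checkers as well as at runtime. A minimal sketch of that pattern with toy stand-ins for the real classes:

from enum import Enum
from typing import NamedTuple, Protocol, TypeGuard, runtime_checkable  # TypeGuard: Python 3.10+

class Format(Enum):
    DEEPEP_NORMAL = "deepep_normal"
    DEEPEP_LL = "deepep_ll"

@runtime_checkable
class Output(Protocol):
    @property
    def format(self) -> Format: ...

class NormalOutput(NamedTuple):
    payload: int

    @property
    def format(self) -> Format:
        return Format.DEEPEP_NORMAL

def format_is_deepep_normal(out: Output) -> TypeGuard[NormalOutput]:
    # After `if format_is_deepep_normal(o): ...`, checkers treat o as NormalOutput.
    return out.format == Format.DEEPEP_NORMAL

assert format_is_deepep_normal(NormalOutput(payload=1))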
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL


-class AscendDeepEPLLOutput(NamedTuple):
-    """AscendDeepEP low latency dispatch output."""
-
-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
-    topk_weights: torch.Tensor
-    masked_m: torch.Tensor
-    seg_indptr: torch.Tensor
-    expected_m: int
-
-    @property
-    def format(self) -> DispatchOutputFormat:
-        return DispatchOutputFormat.ASCENT_LL
-
-
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
-assert isinstance(AscendDeepEPLLOutput, DispatchOutput)


 class DeepEPNormalCombineInput(NamedTuple):

@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             deepep_post_reorder_triton_kernel,
         )

-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:

@@ -553,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )

-        if _is_npu:
-            deepep_output = AscendDeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                self.handle[1],
-                expected_m,
-            )
-        else:
-            deepep_output = DeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                expected_m,
-            )
+        deepep_output = DeepEPLLOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
         return deepep_output

     def _dispatch_core(
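Deleting AscendDeepEPLLOutput means the NPU low-latency path now consumes the shared DeepEPLLOutput; its masked_m field is what _forward_ll in ep_moe/layer.py binds as group_list. A toy illustration of that positional unpacking (shapes and values are made up):

from typing import NamedTuple, Tuple

import torch

class LLOutputSketch(NamedTuple):
    # Field order mirrors the DeepEPLLOutput construction in the diff above.
    hidden_states: Tuple[torch.Tensor, torch.Tensor]  # (int8 data, per-token scales)
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor  # per-expert token counts; becomes group_list
    expected_m: int

out = LLOutputSketch(
    (torch.zeros(4, 8, dtype=torch.int8), torch.ones(4)),
    torch.zeros(4, 2, dtype=torch.int64),
    torch.ones(4, 2),
    torch.tensor([2, 2]),
    2,
)
hidden_states, topk_idx, topk_weights, group_list, _ = out  # as _forward_ll does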
python/sglang/srt/layers/moe/topk.py

@@ -330,6 +330,14 @@ class TopK(CustomOp):
             )
             topk_weights = topk_weights / topk_weights_sum

+        if expert_location_dispatch_info is not None:
+            topk_ids = topk_ids_logical_to_physical(
+                topk_ids, expert_location_dispatch_info
+            )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+
         return StandardTopKOutput(topk_weights, topk_ids, _)
     else:
         self.topk_config.torch_native = True
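This addition is the "redundant experts" half of the feature on this code path: after weight normalization, logical expert ids are remapped to physical slots (one logical expert may own several physical replicas) and the choice is reported to the expert distribution recorder. A toy sketch of the logical-to-physical idea; the replica table and random pick are illustrative, and sglang's topk_ids_logical_to_physical is the real implementation:

import torch

# Hypothetical layout: logical expert 0 is replicated on physical slots 0 and 3.
LOGICAL_TO_PHYSICAL = {0: [0, 3], 1: [1], 2: [2]}

def topk_ids_to_physical(topk_ids: torch.Tensor) -> torch.Tensor:
    # Pick one replica per selection so load spreads across redundant experts.
    chosen = [
        LOGICAL_TO_PHYSICAL[i][torch.randint(len(LOGICAL_TO_PHYSICAL[i]), ()).item()]
        for i in topk_ids.flatten().tolist()
    ]
    return torch.tensor(chosen, dtype=topk_ids.dtype).reshape(topk_ids.shape)

print(topk_ids_to_physical(torch.tensor([[0, 2], [0, 1]])))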
scripts/ci/npu_ci_install_dependency.sh

@@ -51,5 +51,11 @@ ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil
 wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"

+### Install sgl-kernel-npu
+SGL_KERNEL_NPU_TAG="20250901"
+git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
+(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
+
 ### Install SGLang
 ${PIP_INSTALL} -v -e "python[srt_npu]"
test/srt/ascend/test_ascend_deepep.py (new file, 0 → 100644)

import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
        "accuracy": 0.95,
        "latency": 1000,
        "output_throughput": 6,
    },
}


class TestAscendDeepEP(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--mem-fraction-static",
            0.9,
            "--max-running-requests",
            32,
            "--disable-radix-cache",
            "--chunked-prefill-size",
            32768,
            "--disable-cuda-graph",
            "--tp-size",
            16,
            "--dp-size",
            1,
            "--ep-size",
            16,
            "--moe-a2a-backend",
            "deepep",
            "--deepep-mode",
            "auto",
        ]
        cls.extra_envs = {
            "HCCL_BUFFSIZE": "500",
            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
        }
        os.environ.update(cls.extra_envs)

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=1500,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(args)

                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py

@@ -300,6 +300,9 @@ suite_ascend = {
         TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
         TestFile("ascend/test_ascend_tp4_bf16.py", 400),
     ],
+    "per-commit-16-ascend-a3": [
+        TestFile("ascend/test_ascend_deepep.py", 400),
+    ],
 }

 suites.update(suite_amd)