Unverified Commit 5b64f006 authored by Even Zhou, committed by GitHub

[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)

parent 5b7448de
......@@ -127,12 +127,48 @@ jobs:
cd test/srt
python3 run_suite.py --suite per-commit-4-ascend-npu --timeout-per-file 3600
per-commit-16-ascend-a3:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
runs-on: linux-aarch64-a3-16
container:
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
run: |
# speed up by using infra cache services
CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local"
sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list
pip config set global.index-url http://${CACHING_URL}/pypi/simple
pip config set global.trusted-host ${CACHING_URL}
bash scripts/ci/npu_ci_install_dependency.sh
# copy required file from our daily cache
cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp
# download the GSM8K test set through the CI proxy
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test
timeout-minutes: 90
env:
SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true
HF_ENDPOINT: https://hf-mirror.com
TORCH_EXTENSIONS_DIR: /tmp/torch_extensions
run: |
cd test/srt
python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 5400
pr-test-npu-finish:
if: always()
needs:
- per-commit-1-ascend-npu
- per-commit-2-ascend-npu
- per-commit-4-ascend-npu
- per-commit-16-ascend-a3
runs-on: ubuntu-latest
steps:
- name: Check all dependent job statuses
......
......@@ -72,5 +72,6 @@ jobs:
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance: false
build-args: |
SGLANG_KERNEL_NPU_TAG=20250901
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
......@@ -54,8 +54,6 @@ jobs:
run: |
version=$(cat python/sglang/version.py | cut -d'"' -f2)
echo "TAG=lmsysorg/sglang:v$version-cann${{ matrix.cann_version }}-${{ matrix.device_type }}" >> $GITHUB_OUTPUT
kernel_tag=$(curl -s https://api.github.com/repos/sgl-project/sgl-kernel-npu/tags | jq -r '.[0].name')
echo "KERNEL_NPU_TAG=${kernel_tag}" >> $GITHUB_OUTPUT
- name: Build and push Docker image
id: build-and-push
......@@ -70,6 +68,6 @@ jobs:
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance: false
build-args: |
SGLANG_KERNEL_NPU_TAG=${{ steps.get_version.outputs.KERNEL_NPU_TAG }}
SGLANG_KERNEL_NPU_TAG=20250901
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
......@@ -55,7 +55,7 @@ class EPLBManager:
enable_timing = self._rebalance_layers_per_chunk is None
if enable_timing:
torch.cuda.synchronize()
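# torch.get_device_module() resolves to the active accelerator backend (torch.npu on Ascend builds, torch.cuda otherwise), keeping the timing path device-agnostic.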
torch.get_device_module().synchronize()
time_start = time.time()
dump_record_output = get_global_expert_distribution_recorder().dump_record(
......@@ -85,7 +85,7 @@ class EPLBManager:
msg = f"[EPLBManager] rebalance end"
if enable_timing:
torch.cuda.synchronize()
torch.get_device_module().synchronize()
time_end = time.time()
msg += f" time={time_end - time_start:.3f}s"
logger.info(msg)
......
......@@ -30,7 +30,9 @@ import torch.distributed
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
......@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def _on_hook(self, hook_name: str, **kwargs):
if self._disable_all:
return
if not (self._recording or torch.cuda.is_current_stream_capturing()):
if not (
self._recording or torch.get_device_module().is_current_stream_capturing()
):
return
gatherer = self._single_pass_gatherers[
self._accumulator.get_single_pass_gatherer_key(
......@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
super().__init__(*args, **kwargs)
device = "npu" if _is_npu else "cuda"
self._enable_global_physical_experts = enable_global_physical_experts
self._data = torch.zeros(
(
......@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
),
),
dtype=torch.int,
device="cuda",
device=device,
)
def reset(self):
......@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
if self._first_dump:
self._first_dump = False
torch.cuda.empty_cache()
torch.get_device_module().empty_cache()
torch.distributed.all_reduce(
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
......
......@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
):
if self._first_execution:
self._first_execution = False
torch.cuda.empty_cache()
torch.get_device_module().empty_cache()
old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None
......
......@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var
......@@ -33,6 +34,7 @@ class ForwardMetadata:
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
class AscendAttnBackend(AttentionBackend):
......@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
self.forward_metadata = ForwardMetadata()
self.forward_metadata.block_tables = (
......@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
forward_batch.extend_seq_lens_cpu
)
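# Round the last cumulative extend length up to a multiple of the attention TP size (presumably so the padded prefill tokens split evenly across attention TP ranks on Ascend).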
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
......
......@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
......@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
......@@ -718,24 +717,78 @@ class DeepEPMoE(EPMoE):
def forward_npu(
self,
dispatch_output: DeepEPLLOutput,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
if TYPE_CHECKING:
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
import torch_npu
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
# NOTE: Ascend's Dispatch & Combine does not support FP16
output_dtype = torch.bfloat16
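# group_list_type=1: group_list appears to be consumed by npu_grouped_matmul as per-expert token counts rather than cumulative offsets.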
group_list_type = 1
def _forward_normal(dispatch_output: DeepEPNormalOutput):
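# Normal (high-throughput) DeepEP path: per-expert received token counts arrive in num_recv_tokens_per_expert and drive the grouped matmuls below.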
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPNormalOutput)
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
pertoken_scale = hidden_states[1]
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list_type = 1
seg_indptr = seg_indptr.to(torch.int64)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
import torch_npu
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
def _forward_ll(dispatch_output: DeepEPLLOutput):
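# Low-latency DeepEP path: hidden_states arrives as a (quantized activations, per-token scale) pair and group_list carries the per-expert token counts from dispatch.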
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
......@@ -744,7 +797,7 @@ class DeepEPMoE(EPMoE):
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
group_list=group_list,
output_dtype=torch.int32,
)[0]
......@@ -752,11 +805,11 @@ class DeepEPMoE(EPMoE):
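# Fused Ascend kernel: dequantize the int32 gmm1 output with the weight and per-token activation scales, apply SwiGLU, then re-quantize per token for gmm2 (as the op name suggests).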
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=pertoken_scale,
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=seg_indptr,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
......@@ -770,12 +823,19 @@ class DeepEPMoE(EPMoE):
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
return _forward_normal(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
return _forward_ll(dispatch_output)
else:
raise ValueError(f"Unsupported DeepEP format {dispatch_output.format}")
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
if get_moe_a2a_backend().is_deepep():
......
......@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLCombineInput,
......@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
)
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
"CombineInput",
......
......@@ -8,7 +8,6 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPNormalCombineInput,
......@@ -47,19 +46,12 @@ class DispatchOutputChecker:
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum):
STANDARD = "standard"
DEEPEP_NORMAL = "deepep_normal"
DEEPEP_LL = "deepep_ll"
ASCENT_LL = "ascent_ll"
def is_standard(self) -> bool:
return self == DispatchOutputFormat.STANDARD
......@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable
class DispatchOutput(Protocol):
......
......@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple):
"""AscendDeepEP low latency dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
seg_indptr: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
......@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_post_reorder_triton_kernel,
)
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
else:
if hidden_states.shape[0] > 0:
......@@ -553,16 +536,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
)
if _is_npu:
deepep_output = AscendDeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
self.handle[1],
expected_m,
)
else:
deepep_output = DeepEPLLOutput(
hidden_states,
topk_idx,
......
......@@ -330,6 +330,14 @@ class TopK(CustomOp):
)
topk_weights = topk_weights / topk_weights_sum
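# Redundant-experts support: map logical expert ids to their physical replicas and record the selection for the expert distribution (EPLB) statistics.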
if expert_location_dispatch_info is not None:
topk_ids = topk_ids_logical_to_physical(
topk_ids, expert_location_dispatch_info
)
get_global_expert_distribution_recorder().on_select_experts(
topk_ids=topk_ids
)
return StandardTopKOutput(topk_weights, topk_ids, _)
else:
self.topk_config.torch_native = True
......
......@@ -51,5 +51,11 @@ ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"
### Install sgl-kernel-npu
SGL_KERNEL_NPU_TAG="20250901"
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
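# Build and install the DeepEP wheel, then symlink the compiled deep_ep_cpp*.so into the site-packages root (the import appears to expect it there).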
(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
### Install SGLang
${PIP_INSTALL} -v -e "python[srt_npu]"
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
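# Per-model CI thresholds used below: a GSM8K accuracy floor and an offline output-throughput floor (the latency entry is not asserted in this test).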
TEST_MODEL_MATRIX = {
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
"accuracy": 0.95,
"latency": 1000,
"output_throughput": 6,
},
}
class TestAscendDeepEP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--attention-backend",
"ascend",
"--quantization",
"w8a8_int8",
"--mem-fraction-static",
0.9,
"--max-running-requests",
32,
"--disable-radix-cache",
"--chunked-prefill-size",
32768,
"--disable-cuda-graph",
"--tp-size",
16,
"--dp-size",
1,
"--ep-size",
16,
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"auto",
]
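# HCCL_BUFFSIZE sets the HCCL communication buffer size (in MB); the DeepEP variable caps dispatch tokens per rank for this 16-NPU run.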
cls.extra_envs = {
"HCCL_BUFFSIZE": "500",
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
}
os.environ.update(cls.extra_envs)
def test_a_gsm8k(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=1500,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
......@@ -300,6 +300,9 @@ suite_ascend = {
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
TestFile("ascend/test_ascend_tp4_bf16.py", 400),
],
"per-commit-16-ascend-a3": [
TestFile("ascend/test_ascend_deepep.py", 400),
],
}
suites.update(suite_amd)
......