Unverified commit a40229f6 authored by Xun Sun, committed by GitHub

[1/N] Introduce Mooncake Backend and Mooncake EP to Support Elastic EP (#10423)


Co-authored-by: Hank Han <hanhan7630@outlook.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent 74737b28
......@@ -134,6 +134,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None |
| `--mooncake-ib-device` | The InfiniBand devices used by the Mooncake backend; accepts multiple comma-separated devices. Defaults to None, which triggers automatic device detection when the Mooncake backend is enabled. | None |
| `--tp-size` | The tensor parallelism size. | 1 |
| `--pp-size` | The pipeline parallelism size. | 1 |
| `--pp-max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None |
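The two new elastic-EP flags above are used together with the existing parallelism flags. A minimal launch sketch (the model path and InfiniBand device names below are placeholders, not values taken from this change):

```python
# Illustrative only: launching SGLang with the Mooncake elastic-EP backend from
# Python. The model path and IB device names are placeholders.
import subprocess

subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        "--tp-size", "4",
        "--elastic-ep-backend", "mooncake",              # new flag
        "--mooncake-ib-device", "mlx5_0,mlx5_1",         # new flag (optional)
        "--moe-a2a-backend", "deepep",
        "--deepep-mode", "low_latency",
    ],
    check=True,
)
```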
......@@ -246,7 +248,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-a2a-backend` | Select the backend for all-to-all communication in expert parallelism; can be `deepep` or `mooncake`. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | auto |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
......
......@@ -43,6 +43,7 @@ from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_int_env_var,
get_local_ip_auto,
is_cpu,
is_cuda_alike,
is_hip,
......@@ -258,11 +259,14 @@ class GroupCoordinator:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
# a cpu_group to allow direct coordination between processes through
# the CPU. The backend is chosen based on `torch_distributed_backend`
if "mooncake" in torch_distributed_backend:
cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
else:
cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
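For `new_group(..., backend="mooncake")` and `backend="mooncake-cpu"` to be accepted here, the Mooncake package has to register those names as third-party process-group backends with PyTorch. That registration is not part of this diff; the sketch below only illustrates the standard hook (`torch.distributed.Backend.register_backend`) with a hypothetical constructor.

```python
# Illustrative only: how a third-party backend name such as "mooncake-cpu" becomes
# usable in torch.distributed.new_group(). The constructor below is hypothetical;
# the real registration lives inside the mooncake package.
import torch.distributed as dist

def _make_mooncake_cpu_group(store, rank, world_size, timeout):
    # Would build and return a ProcessGroup backed by Mooncake's CPU-side
    # transport; omitted in this sketch.
    raise NotImplementedError

dist.Backend.register_backend(
    "mooncake-cpu",
    _make_mooncake_cpu_group,
    devices=["cpu"],  # this backend coordinates over the CPU, like gloo
)
```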
......@@ -1410,6 +1414,17 @@ def init_distributed_environment(
distributed_init_method,
backend,
)
if "mooncake" in backend:
try:
from mooncake import ep as mooncake_ep
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run SGLang with Mooncake Backend."
) from e
mooncake_ep.set_host_ip(get_local_ip_auto())
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
......
......@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
class DeepEPMoE(FusedMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
Mooncake EP reuses this class, since the two backends expose the same interface.
"""
_has_printed = False
......@@ -686,7 +687,7 @@ class DeepEPMoE(FusedMoE):
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_deepep():
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
......
......@@ -16,6 +16,11 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPNormalCombineInput,
DeepEPNormalOutput,
)
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
MooncakeCombineInput,
MooncakeDispatchOutput,
MooncakeEPDispatcher,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
......@@ -30,6 +35,9 @@ __all__ = [
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"MooncakeCombineInput",
"MooncakeDispatchOutput",
"MooncakeEPDispatcher",
"StandardDispatchOutput",
"StandardCombineInput",
"DeepEPConfig",
......
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.utils import get_int_env_var
try:
from mooncake.mooncake_ep_buffer import Buffer
use_mooncake_ep = True
except ImportError:
use_mooncake_ep = False
from enum import Enum, IntEnum, auto
import torch
import torch.distributed as dist
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
logger = logging.getLogger(__name__)
class MooncakeDispatchOutput(NamedTuple):
"""Mooncake EP dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.DEEPEP_LL
assert isinstance(MooncakeDispatchOutput, DispatchOutput)
class MooncakeCombineInput(NamedTuple):
"""Mooncake EP combine input."""
pass
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_LL
assert isinstance(MooncakeCombineInput, CombineInput)
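# Shared tensor handed to Mooncake's dispatch/combine calls as `active_ranks`;
# its entries mark which EP participants are currently considered active, which
# is what the elastic-EP support builds on.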
_ACTIVE_RANKS: Optional[torch.Tensor] = None
def get_ep_active_ranks() -> torch.Tensor:
assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
return _ACTIVE_RANKS
class EPBuffer:
_buffer = None
_hidden_size: Optional[int] = None
_num_max_dispatch_tokens_per_rank: Optional[int] = None
_num_experts: Optional[int] = None
@classmethod
def get_ep_buffer(
cls,
group: dist.ProcessGroup,
hidden_size: int,
param_bytes: int,
deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: int = -1,
num_experts: int = -1,
):
if cls._buffer is not None:
return cls._buffer
cls._hidden_size = hidden_size
cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
cls._num_experts = num_experts
num_ep_buffer_bytes = 0
if deepep_mode.enable_normal():
raise NotImplementedError(
"Normal mode is not supported for Mooncake EP yet."
)
if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank != -1
assert num_experts != -1 and num_experts % group.size() == 0
num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
num_max_dispatch_tokens_per_rank,
hidden_size,
group.size(),
num_experts,
)
cls._buffer = Buffer(group, num_ep_buffer_bytes)
return cls._buffer
class _MooncakeEPDispatcherImpl:
def __init__(
self,
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool,
num_experts: int,
num_local_experts: int,
hidden_size: int,
params_dtype: torch.dtype,
return_recv_hook: bool,
deepep_mode: DeepEPMode,
):
if not use_mooncake_ep:
raise ImportError(
"Mooncake EP is not installed. Please install Mooncake package at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
"with EP support to run SGLang with Mooncake EP."
)
self.group = group
self.router_topk = router_topk
self.permute_fusion = permute_fusion
self.num_experts = num_experts
self.num_local_experts = num_local_experts
self.hidden_size = hidden_size
self.params_dtype = params_dtype
self.return_recv_hook = return_recv_hook
self.deepep_mode = deepep_mode
self.params_bytes = 2
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
"SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
)
# Mooncake EP dispatch uses FINISHED_SUM_TAG=1024, and the protocol requires
# the number of tokens sent from one rank to another to stay below this value.
assert self.num_max_dispatch_tokens_per_rank <= 1024
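# The first dispatch/combine runs without a timeout (-1), presumably so that
# one-time initialization can complete; later calls use `timeout_us` (10 s) so
# that a vanished peer cannot block the collective indefinitely.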
self.first_execution = True
self.timeout_us = 10000000
global _ACTIVE_RANKS
if _ACTIVE_RANKS is None:
_ACTIVE_RANKS = torch.ones(
(self.num_experts,), dtype=torch.int32, device="cuda"
)
self.active_ranks = _ACTIVE_RANKS
self.handle = None
def dispatch_a(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
buffer = self._get_buffer()
topk_idx = topk_idx.to(torch.int64)
expected_m = (
hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+ self.num_experts
) // self.num_experts
hidden_states, masked_m, event, hook = self._dispatch_core(
hidden_states,
topk_idx,
use_fp8=True,
)
return (
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
event,
hook,
)
def dispatch_b(
self,
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
event,
hook,
):
hook() if self.return_recv_hook else event.current_stream_wait()
get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
masked_m
)
return MooncakeDispatchOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
def _dispatch_core(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
use_fp8: bool = False,
):
buffer = self._get_buffer()
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
buffer.dispatch(
hidden_states,
topk_idx,
self.active_ranks,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
-1 if self.first_execution else self.timeout_us,
use_fp8=use_fp8,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
)
return packed_recv_hidden, packed_recv_count, event, hook
def combine_a(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
hidden_states, event, hook = self._combine_core(
hidden_states,
topk_idx,
topk_weights,
)
return hidden_states, event, hook
def combine_b(self, hidden_states, event, hook):
hook() if self.return_recv_hook else event.current_stream_wait()
return hidden_states
def _combine_core(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
):
buffer = self._get_buffer()
combined_hidden_states, event, hook = buffer.combine(
hidden_states,
topk_idx,
topk_weights,
self.active_ranks,
-1 if self.first_execution else self.timeout_us,
self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
self.first_execution = False
self.handle = None
return combined_hidden_states, event, hook
def _get_buffer(self):
return EPBuffer.get_ep_buffer(
self.group,
self.hidden_size,
self.params_bytes,
self.deepep_mode,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
)
class _Stage(Enum):
INITIAL = auto()
AFTER_DISPATCH_A = auto()
AFTER_DISPATCH_B = auto()
AFTER_COMBINE_A = auto()
class MooncakeEPDispatcher(BaseDispatcher):
def __init__(
self,
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
num_experts: int = None,
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
async_finish: bool = False,
return_recv_hook: bool = False,
):
self.deepep_mode = deepep_mode
if self.deepep_mode.enable_low_latency():
self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
group=group,
router_topk=router_topk,
permute_fusion=permute_fusion,
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
return_recv_hook=return_recv_hook,
deepep_mode=deepep_mode,
)
if self.deepep_mode.enable_normal():
raise NotImplementedError
self._stage = _Stage.INITIAL
def dispatch(self, *args, **kwargs) -> DispatchOutput:
self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b()
return ret
def dispatch_a(
self,
hidden_states: torch.Tensor,
input_global_scale: Optional[torch.Tensor],
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
inner_state = self._get_impl(forward_batch).dispatch_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
)
self._dispatch_intermediate_state = forward_batch, inner_state
def dispatch_b(self):
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
forward_batch, inner_state = self._dispatch_intermediate_state
del self._dispatch_intermediate_state
return self._get_impl(forward_batch).dispatch_b(*inner_state)
def combine(self, *args, **kwargs) -> Tuple:
self.combine_a(*args, **kwargs)
ret = self.combine_b()
return ret
def combine_a(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
overlap_args: Optional = None,
):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_batch).combine_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
)
self._combine_intermediate_state = forward_batch, inner_state
def combine_b(self):
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
forward_batch, inner_state = self._combine_intermediate_state
del self._combine_intermediate_state
return self._get_impl(forward_batch).combine_b(*inner_state)
def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.NORMAL:
raise NotImplementedError
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
def _update_stage(self, old_stage, new_stage):
assert self._stage == old_stage
self._stage = new_stage
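Like `DeepEPDispatcher`, this dispatcher is driven in two phases so that communication can overlap with other work. A minimal sketch of the call order (the tensors, process group, and expert computation are placeholders, not code from this change):

```python
# Illustrative call order for MooncakeEPDispatcher in low-latency mode. Only the
# method signatures defined above are assumed; the expert computation is a stub.
def run_grouped_experts(dispatch_out):
    # Placeholder for the grouped expert GEMMs (DeepEPMoE drives this in SGLang).
    raise NotImplementedError

def moe_layer_forward(dispatcher, hidden_states, topk_idx, topk_weights, forward_batch):
    # Phase A launches the all-to-all transfer; phase B waits on the recv hook
    # (or CUDA event), so independent work can be scheduled in between.
    dispatcher.dispatch_a(
        hidden_states=hidden_states,
        input_global_scale=None,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )
    out = dispatcher.dispatch_b()          # -> MooncakeDispatchOutput
    expert_out = run_grouped_experts(out)  # placeholder
    dispatcher.combine_a(
        hidden_states=expert_out,
        topk_idx=out.topk_idx,
        topk_weights=out.topk_weights,
        forward_batch=forward_batch,
    )
    return dispatcher.combine_b()
```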
......@@ -24,6 +24,7 @@ class MoeA2ABackend(Enum):
NONE = "none"
DEEPEP = "deepep"
MOONCAKE = "mooncake"
@classmethod
def _missing_(cls, value):
......@@ -40,6 +41,9 @@ class MoeA2ABackend(Enum):
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_mooncake(self):
return self == MoeA2ABackend.MOONCAKE
class MoeRunnerBackend(Enum):
......
......@@ -677,7 +677,18 @@ class ModelRunner:
raise
if self.device == "cuda":
backend = "nccl"
if self.server_args.elastic_ep_backend == "mooncake":
backend = "mooncake"
if self.server_args.mooncake_ib_device:
mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
try:
from mooncake import ep as mooncake_ep
mooncake_ep.set_device_filter(mooncake_ib_device)
except ImportError:
pass  # If Mooncake is missing, an informative ImportError is raised in `init_distributed_environment`.
else:
backend = "nccl"
elif self.device == "xpu":
backend = "xccl"
elif self.device == "hpu":
......@@ -885,17 +896,23 @@ class ModelRunner:
f"mem usage={self.weight_load_mem_usage:.2f} GB."
)
# Handle the case where some ranks do not finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
if self.server_args.elastic_ep_backend == "mooncake":
# Mooncake does not support `monitored_barrier`
dist.barrier(group=get_tp_group().cpu_group)
else:
# Handle the case where some ranks do not finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(
seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
def update_expert_location(
self,
......
......@@ -592,6 +592,7 @@ class DeepseekV2MoE(nn.Module):
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
or get_moe_a2a_backend().is_mooncake()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else {}
),
......@@ -622,7 +623,7 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
if get_moe_a2a_backend().is_deepep():
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
......@@ -651,7 +652,9 @@ class DeepseekV2MoE(nn.Module):
return_recv_hook=True,
)
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
self._enable_a2a_moe = (
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
)
def get_moe_weights(self):
return [
......@@ -668,7 +671,7 @@ class DeepseekV2MoE(nn.Module):
use_reduce_scatter: bool = False,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
if not self._enable_deepep_moe:
if not self._enable_a2a_moe:
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
......
......@@ -228,6 +228,8 @@ class ServerArgs:
# Runtime options
device: Optional[str] = None
elastic_ep_backend: Literal[None, "mooncake"] = None
mooncake_ib_device: Optional[str] = None
tp_size: int = 1
pp_size: int = 1
pp_max_micro_batch_size: Optional[int] = None
......@@ -344,7 +346,7 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_a2a_backend: Literal["none", "deepep", "mooncake"] = "none"
moe_runner_backend: str = "auto"
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
......@@ -537,7 +539,7 @@ class ServerArgs:
# Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_deepep_moe()
self._handle_a2a_moe()
self._handle_eplb_and_dispatch()
self._handle_expert_distribution_metrics()
......@@ -1091,7 +1093,7 @@ class ServerArgs:
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
)
def _handle_deepep_moe(self):
def _handle_a2a_moe(self):
if self.moe_a2a_backend == "deepep":
if self.deepep_mode == "normal":
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
......@@ -1101,6 +1103,12 @@ class ServerArgs:
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
if self.moe_a2a_backend == "mooncake":
self.ep_size = self.tp_size
logger.warning(
f"Mooncake MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
def _handle_eplb_and_dispatch(self):
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
self.expert_distribution_recorder_mode = "stat"
......@@ -1712,6 +1720,21 @@ class ServerArgs:
default=ServerArgs.device,
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
)
parser.add_argument(
"--elastic-ep-backend",
type=str,
default=ServerArgs.elastic_ep_backend,
choices=["none", "mooncake"],
help="Specify the collective communication backend for elastic EP. Currently supports 'mooncake'.",
)
parser.add_argument(
"--mooncake-ib-device",
type=str,
default=ServerArgs.mooncake_ib_device,
help="The InfiniBand devices for Mooncake Backend transfer, accepts multiple comma-separated devices "
"(e.g., --mooncake-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when Mooncake Backend is enabled.",
)
parser.add_argument(
"--tensor-parallel-size",
"--tp-size",
......@@ -2333,7 +2356,7 @@ class ServerArgs:
parser.add_argument(
"--moe-a2a-backend",
type=str,
choices=["none", "deepep"],
choices=["none", "deepep", "mooncake"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
......
......@@ -20,7 +20,10 @@ from sglang.srt.layers.moe import (
get_tbo_token_distribution_threshold,
is_tbo_enabled,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher,
MooncakeEPDispatcher,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
......@@ -363,7 +366,7 @@ class TboDPAttentionPreparer:
):
deepep_mode = get_deepep_mode()
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
enable_a2a_moe = not get_moe_a2a_backend().is_none()
enable_two_batch_overlap = is_tbo_enabled()
self.enable_two_batch_overlap = enable_two_batch_overlap
......@@ -392,7 +395,7 @@ class TboDPAttentionPreparer:
local_batch.forward_mode.is_extend()
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and enable_a2a_moe
and (resolved_deepep_mode.is_low_latency())
)
else:
......@@ -968,9 +971,14 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs):
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
if get_moe_a2a_backend().is_deepep():
self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
elif get_moe_a2a_backend().is_mooncake():
self._inners = [
MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
......
......@@ -10,6 +10,10 @@ export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
export CUDA_HOME=/usr/local/cuda
# Install Mooncake+EP
curl -L https://cloud.tsinghua.edu.cn/f/c22ec766545e48bf99e8/?dl=1 -o mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl
UV_SYSTEM_PYTHON=true uv pip install mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl
if python3 -c "import deep_ep" >/dev/null 2>&1; then
echo "deep_ep is already installed or importable. Skipping installation."
exit 0
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPureDP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"512",
"--mem-fraction-static",
"0.5",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestHybridDPTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"2",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"256",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"128",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestNoGatherdBuffer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--enable-dp-lm-head",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"32",
"--max-running-requests",
"512",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestTBO(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--enable-two-batch-overlap",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"512",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
if __name__ == "__main__":
unittest.main()
......@@ -171,6 +171,7 @@ suites = {
],
"per-commit-4-gpu-deepep": [
TestFile("ep/test_deepep_small.py", 531),
TestFile("ep/test_mooncake_ep_small.py", 450),
],
"per-commit-8-gpu-deepep": [
TestFile("ep/test_deepep_large.py", 338),
......