Unverified Commit 6c88f6c8 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)

parent c8d3a402
......@@ -23,7 +23,7 @@ spec:
- /bin/bash
- -c
# please modify the sglang serving arguments below, as necessary.
# NOTE: the --expert-parallel-size and --enable-ep-moe are for MoE model like DeepSeek-R1
# NOTE: the --expert-parallel-size is for MoE model like DeepSeek-R1
args:
- |
python3 -m sglang.launch_server \
......@@ -36,7 +36,6 @@ spec:
--host 0.0.0.0 \
--port 8000 \
--enable-metrics \
--enable-ep-moe \
--expert-parallel-size 16
env:
- name: POD_INDEX # reflects the node-rank
......
......@@ -39,13 +39,13 @@ $ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 -
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
```
### Advanced Configuration
......@@ -103,13 +103,13 @@ $ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 -
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
```
## ASCEND
......
......@@ -212,8 +212,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--enable-ep-moe` | Enabling expert parallelism for moe. The ep size is equal to the tp size. | False |
| `--enable-deepep-moe` | Enabling DeepEP MoE implementation for EP MoE. | False |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None |
| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False |
| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
......
......@@ -28,9 +28,8 @@ spec:
- --enable-dp-lm-head
- --dp-size
- "16"
- --enable-deepep-moe
- --deepep-mode
- low_latency
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- decode
- --mem-fraction-static
......@@ -166,9 +165,8 @@ spec:
- --enable-dp-lm-head
- --dp-size
- "16"
- --enable-deepep-moe
- --deepep-mode
- low_latency
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- decode
- --mem-fraction-static
......
......@@ -38,9 +38,8 @@ spec:
- --dp-size
- "16"
- --disable-radix-cache
- --enable-deepep-moe
- --deepep-mode
- normal
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- prefill
- --mem-fraction-static
......@@ -184,9 +183,8 @@ spec:
- --dp-size
- "16"
- --disable-radix-cache
- --enable-deepep-moe
- --deepep-mode
- normal
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- prefill
- --mem-fraction-static
......
......@@ -64,9 +64,8 @@ spec:
- --dp-size
- "16"
- --disable-radix-cache
- --enable-deepep-moe
- --deepep-mode
- normal
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- prefill
- --mem-fraction-static
......@@ -212,9 +211,8 @@ spec:
- --dp-size
- "16"
- --disable-radix-cache
- --enable-deepep-moe
- --deepep-mode
- normal
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- prefill
- --mem-fraction-static
......@@ -373,9 +371,8 @@ spec:
- --enable-dp-lm-head
- --dp-size
- "16"
- --enable-deepep-moe
- --deepep-mode
- low_latency
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- decode
- --mem-fraction-static
......@@ -508,9 +505,8 @@ spec:
#- --enable-two-batch-overlap
- --dp-size
- "16"
- --enable-deepep-moe
- --deepep-mode
- low_latency
- --moe-a2a-backend
- deepep
- --disaggregation-mode
- decode
- --mem-fraction-static
......
......@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
)
if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
if server_args.moe_a2a_backend is not None and (
server_args.deepep_mode == "normal"
):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else:
raise NotImplementedError
if server_args.enable_deepep_moe:
if server_args.moe_a2a_backend is not None:
if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":
......
......@@ -108,7 +108,7 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if global_server_args_dict["enable_deepep_moe"]
if not global_server_args_dict["moe_a2a_backend"].is_standard()
else ScatterMode.FULL
)
else:
......
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
......@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,
......@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
DeepEPMode,
ceil_div,
dispose_tensor,
get_bool_env_var,
is_hip,
is_npu,
)
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
......@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
)
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
......@@ -328,7 +304,7 @@ class DeepEPMoE(EPMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
):
super().__init__(
num_experts=num_experts,
......@@ -348,7 +324,6 @@ class DeepEPMoE(EPMoE):
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
......@@ -762,11 +737,10 @@ class FlashInferEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
if get_moe_expert_parallel_world_size() > 1:
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
......@@ -14,8 +14,6 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
......@@ -94,7 +92,6 @@ class FusedMoE(torch.nn.Module):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()
......@@ -112,7 +109,6 @@ class FusedMoE(torch.nn.Module):
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
enable_ep_moe = False
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
self.moe_ep_size = get_moe_expert_parallel_world_size()
......@@ -121,7 +117,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
if enable_ep_moe:
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
......
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)
__all__ = [
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
]
......@@ -2,11 +2,22 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
import torch
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self):
return self == MoEA2ABackend.none
def is_deepep(self):
return self == MoEA2ABackend.deepep
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
......
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
from __future__ import annotations
import logging
......@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
DeepEPMode,
get_bool_env_var,
get_int_env_var,
is_hip,
load_json_config,
)
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
try:
from deep_ep import Buffer, Config
......@@ -150,9 +143,9 @@ class DeepEPBuffer:
num_rdma_bytes,
)
if deepep_mode == DeepEPMode.normal:
if deepep_mode == DeepEPMode.NORMAL:
num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
num_qps_per_rank = num_experts // group.size()
else:
raise NotImplementedError
......@@ -161,7 +154,7 @@ class DeepEPBuffer:
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.low_latency)
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
......@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
async_finish: bool = False,
return_recv_hook: bool = False,
):
......@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
if resolved_deepep_mode == DeepEPMode.NORMAL:
return self._normal_dispatcher
elif resolved_deepep_mode == DeepEPMode.low_latency:
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
......
from enum import Enum
class MoeA2ABackend(Enum):
STANDARD = ("standard", "none")
DEEPEP = "deepep"
@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
for member in cls:
if value in member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD
class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"
def enable_normal(self):
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self):
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.AUTO:
return self
if is_extend_in_batch:
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
......@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
......@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_dp_attention",
"enable_two_batch_overlap",
"enable_dp_lm_head",
"enable_deepep_moe",
"moe_a2a_backend",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",
......
......@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
......@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_gc_logger,
......@@ -1762,8 +1762,10 @@ class Scheduler(
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
......
......@@ -38,6 +38,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
get_attention_dp_rank,
......@@ -839,7 +840,7 @@ class ForwardBatch:
def enable_num_token_non_padded(server_args):
return server_args.enable_ep_moe or server_args.enable_deepep_moe
return get_moe_expert_parallel_world_size() > 1
class PPProxyTensors:
......
......@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
......@@ -217,6 +218,10 @@ class ModelRunner:
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
)
# CPU offload
......
......@@ -29,6 +29,7 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
......@@ -61,7 +62,6 @@ from sglang.srt.layers.moe.ep_moe.layer import (
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,
......@@ -333,15 +332,14 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
......@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)
......@@ -404,9 +402,9 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
......@@ -428,12 +426,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
def get_moe_weights(self):
return [
......@@ -2104,11 +2102,8 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_shared_experts != 1
):
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
......@@ -50,7 +51,6 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
......@@ -83,7 +83,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,
......@@ -443,15 +442,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
......@@ -484,7 +482,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)
......@@ -502,9 +500,9 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
......@@ -526,12 +524,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
......@@ -737,11 +735,8 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment