Unverified commit 6c88f6c8 authored by Cheng Wan, committed by GitHub

[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)

parent c8d3a402
@@ -23,7 +23,7 @@ spec:
          - /bin/bash
          - -c
          # please modify the sglang serving arguments below, as necessary.
-         # NOTE: the --expert-parallel-size and --enable-ep-moe are for MoE model like DeepSeek-R1
+         # NOTE: the --expert-parallel-size is for MoE model like DeepSeek-R1
          args:
          - |
            python3 -m sglang.launch_server \
@@ -36,7 +36,6 @@ spec:
              --host 0.0.0.0 \
              --port 8000 \
              --enable-metrics \
-             --enable-ep-moe \
              --expert-parallel-size 16
          env:
          - name: POD_INDEX # reflects the node-rank
......
@@ -39,13 +39,13 @@ $ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 -
  ```bash
  # prefill 0
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
  # prefill 1
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
  # decode 0
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
  # decode 1
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
  ```
  ### Advanced Configuration
@@ -103,13 +103,13 @@ $ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 -
  ```bash
  # prefill 0
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
  # prefill 1
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
  # decode 0
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
  # decode 1
- $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
+ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
  ```
  ## ASCEND
......
@@ -212,8 +212,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
  | Arguments | Description | Defaults |
  |-----------|-------------|----------|
  | `--ep-size` | The expert parallelism size. | 1 |
- | `--enable-ep-moe` | Enabling expert parallelism for moe. The ep size is equal to the tp size. | False |
- | `--enable-deepep-moe` | Enabling DeepEP MoE implementation for EP MoE. | False |
+ | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None |
  | `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False |
  | `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
  | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
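For readers migrating launch commands: the boolean flags `--enable-ep-moe` and `--enable-deepep-moe` are folded into the single `--moe-a2a-backend` argument, with `--ep-size` still controlling the expert-parallel width. A minimal sketch of how the new value is interpreted, assuming the `MoeA2ABackend` enum this commit adds in `sglang/srt/layers/moe/utils.py` (shown further down in this diff):

```python
# Assumes sglang from this branch is importable; the import path matches the
# new module added later in this commit.
from sglang.srt.layers.moe.utils import MoeA2ABackend

assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD    # flag omitted -> standard all-to-all
assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD  # explicit default
assert MoeA2ABackend("deepep").is_deepep()              # --moe-a2a-backend deepep
```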
......
@@ -28,9 +28,8 @@ spec:
          - --enable-dp-lm-head
          - --dp-size
          - "16"
-         - --enable-deepep-moe
-         - --deepep-mode
-         - low_latency
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - decode
          - --mem-fraction-static
@@ -166,9 +165,8 @@ spec:
          - --enable-dp-lm-head
          - --dp-size
          - "16"
-         - --enable-deepep-moe
-         - --deepep-mode
-         - low_latency
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - decode
          - --mem-fraction-static
......
@@ -38,9 +38,8 @@ spec:
          - --dp-size
          - "16"
          - --disable-radix-cache
-         - --enable-deepep-moe
-         - --deepep-mode
-         - normal
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - prefill
          - --mem-fraction-static
@@ -184,9 +183,8 @@ spec:
          - --dp-size
          - "16"
          - --disable-radix-cache
-         - --enable-deepep-moe
-         - --deepep-mode
-         - normal
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - prefill
          - --mem-fraction-static
......
@@ -64,9 +64,8 @@ spec:
          - --dp-size
          - "16"
          - --disable-radix-cache
-         - --enable-deepep-moe
-         - --deepep-mode
-         - normal
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - prefill
          - --mem-fraction-static
@@ -212,9 +211,8 @@ spec:
          - --dp-size
          - "16"
          - --disable-radix-cache
-         - --enable-deepep-moe
-         - --deepep-mode
-         - normal
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - prefill
          - --mem-fraction-static
@@ -373,9 +371,8 @@ spec:
          - --enable-dp-lm-head
          - --dp-size
          - "16"
-         - --enable-deepep-moe
-         - --deepep-mode
-         - low_latency
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - decode
          - --mem-fraction-static
@@ -508,9 +505,8 @@ spec:
          #- --enable-two-batch-overlap
          - --dp-size
          - "16"
-         - --enable-deepep-moe
-         - --deepep-mode
-         - low_latency
+         - --moe-a2a-backend
+         - deepep
          - --disaggregation-mode
          - decode
          - --mem-fraction-static
......
@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
          )
          if server_args.expert_distribution_recorder_mode == "stat_approx":
-             if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+             if server_args.moe_a2a_backend is not None and (
+                 server_args.deepep_mode == "normal"
+             ):
                  return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
              else:
                  raise NotImplementedError
-         if server_args.enable_deepep_moe:
+         if server_args.moe_a2a_backend is not None:
              if server_args.deepep_mode == "normal":
                  return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
              elif server_args.deepep_mode == "low_latency":
......
@@ -108,7 +108,7 @@ class LayerScatterModes:
          if context.is_layer_sparse:
              return (
                  ScatterMode.SCATTERED
-                 if global_server_args_dict["enable_deepep_moe"]
+                 if not global_server_args_dict["moe_a2a_backend"].is_standard()
                  else ScatterMode.FULL
              )
          else:
......
  from __future__ import annotations
  import logging
- from typing import TYPE_CHECKING, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional
  import torch
- from sglang.srt.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
- )
- from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
  from sglang.srt.layers.moe.ep_moe.kernels import (
      ep_gather,
      ep_scatter,
-     gelu_and_mul_triton_kernel,
-     grouped_gemm_triton,
      moe_ep_deepgemm_preprocess,
      post_reorder_triton_kernel,
-     pre_reorder_triton_kernel,
-     pre_reorder_triton_kernel_for_cutlass_moe,
-     run_cutlass_moe_ep_preproess,
-     run_moe_ep_preproess,
      silu_and_mul_masked_post_quant_fwd,
-     silu_and_mul_triton_kernel,
      tma_align_input_scale,
  )
  from sglang.srt.layers.moe.fused_moe_triton.layer import (
@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
      should_use_flashinfer_trtllm_moe,
  )
  from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode
  from sglang.srt.layers.quantization import deep_gemm_wrapper
- from sglang.srt.layers.quantization.base_config import (
-     QuantizationConfig,
-     QuantizeMethodBase,
- )
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8 import (
      Fp8Config,
      Fp8MoEMethod,
@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
  from sglang.srt.layers.quantization.fp8_kernel import (
      is_fp8_fnuz,
      sglang_per_token_group_quant_fp8,
-     sglang_per_token_quant_fp8,
  )
- from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
- from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import (
-     DeepEPMode,
-     ceil_div,
-     dispose_tensor,
-     get_bool_env_var,
-     is_hip,
-     is_npu,
- )
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+     from sglang.srt.layers.moe.token_dispatcher import (
          DeepEPLLOutput,
          DeepEPNormalOutput,
          DispatchOutput,
@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
              activation=activation,
              # apply_router_weight_on_input=apply_router_weight_on_input,
              routed_scaling_factor=routed_scaling_factor,
-             enable_ep_moe=True,
          )
          self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -328,7 +304,7 @@ class DeepEPMoE(EPMoE):
          prefix: str = "",
          activation: str = "silu",
          routed_scaling_factor: Optional[float] = None,
-         deepep_mode: DeepEPMode = DeepEPMode.auto,
+         deepep_mode: DeepEPMode = DeepEPMode.AUTO,
      ):
          super().__init__(
              num_experts=num_experts,
@@ -348,7 +324,6 @@ class DeepEPMoE(EPMoE):
          # TODO: move to the beginning of the file
          from sglang.srt.distributed.parallel_state import get_tp_group
-         from sglang.srt.managers.schedule_batch import global_server_args_dict
          from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
          self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -762,11 +737,10 @@ class FlashInferEPMoE(EPMoE):
  def get_moe_impl_class():
-     if global_server_args_dict["enable_deepep_moe"]:
+     if global_server_args_dict["moe_a2a_backend"].is_deepep():
          return DeepEPMoE
      if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
          return FusedMoE
-     if global_server_args_dict["enable_ep_moe"]:
+     if get_moe_expert_parallel_world_size() > 1:
          return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
      return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
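The rewritten `get_moe_impl_class` above now keys off the A2A backend enum, the FlashInfer flags, and the MoE expert-parallel world size instead of the removed booleans. A stand-alone sketch of the same decision order, with placeholder names (the real function reads `global_server_args_dict` and returns the `DeepEPMoE`/`FusedMoE`/`EPMoE`/FlashInfer classes shown above):

```python
# Hedged sketch: mirrors the branch order of get_moe_impl_class above.
# MoEImpl and the parameters below are stand-ins, not sglang APIs.
from enum import Enum, auto

class MoEImpl(Enum):
    DEEPEP = auto()
    FUSED = auto()
    EP = auto()
    FLASHINFER_EP = auto()
    FLASHINFER_FUSED = auto()

def select_moe_impl(
    a2a_backend_is_deepep: bool,   # --moe-a2a-backend deepep
    flashinfer_cutlass: bool,      # --enable-flashinfer-cutlass-moe
    flashinfer_trtllm: bool,       # stand-in for should_use_flashinfer_trtllm_moe()
    moe_ep_world_size: int,        # derived from --ep-size; no boolean flag anymore
) -> MoEImpl:
    if a2a_backend_is_deepep:
        return MoEImpl.DEEPEP
    if flashinfer_cutlass:
        return MoEImpl.FUSED
    if moe_ep_world_size > 1:
        return MoEImpl.FLASHINFER_EP if flashinfer_trtllm else MoEImpl.EP
    return MoEImpl.FLASHINFER_FUSED if flashinfer_trtllm else MoEImpl.FUSED
```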
@@ -14,8 +14,6 @@ from sglang.srt.distributed import (
      get_moe_expert_parallel_world_size,
      get_moe_tensor_parallel_rank,
      get_moe_tensor_parallel_world_size,
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
@@ -94,7 +92,6 @@ class FusedMoE(torch.nn.Module):
          no_combine: bool = False,
          routed_scaling_factor: Optional[float] = None,
          enable_flashinfer_cutlass_moe: Optional[bool] = False,
-         enable_ep_moe: Optional[bool] = False,
      ):
          super().__init__()
@@ -112,7 +109,6 @@ class FusedMoE(torch.nn.Module):
          if enable_flashinfer_cutlass_moe and quant_config is None:
              logger.warning("Disable flashinfer MoE when quantization config is None.")
              enable_flashinfer_cutlass_moe = False
-             enable_ep_moe = False
          self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
          self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -121,7 +117,7 @@ class FusedMoE(torch.nn.Module):
          self.moe_tp_rank = get_moe_tensor_parallel_rank()
          assert num_experts % self.moe_ep_size == 0
          self.num_local_experts = num_experts // self.moe_ep_size
-         if enable_ep_moe:
+         if self.moe_ep_size > 1:
              # TODO(ch-wan): support shared experts fusion
              # Create a tensor of size num_experts filled with -1
              self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
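The hunk above replaces the `enable_ep_moe` constructor flag with a check on the MoE EP world size before building the expert map; the fill step itself is collapsed out of this hunk. A purely illustrative sketch of what such a map typically looks like, inferred from `num_local_experts` and `start_expert_id` seen earlier in this diff, not the verbatim implementation:

```python
# Assumption-labelled sketch: -1 marks experts not hosted on this EP rank; local
# experts are renumbered 0..num_local-1, consistent with
# start_expert_id = moe_ep_rank * num_local_experts in EPMoE above.
import torch

def build_expert_map(num_experts: int, moe_ep_size: int, moe_ep_rank: int) -> torch.Tensor:
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    num_local = num_experts // moe_ep_size
    start = moe_ep_rank * num_local
    expert_map[start:start + num_local] = torch.arange(num_local, dtype=torch.int32)
    return expert_map

# e.g. 8 experts over 2 EP ranks, rank 1 -> tensor([-1, -1, -1, -1, 0, 1, 2, 3])
```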
......
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)
__all__ = [
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
]
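The new `token_dispatcher/__init__.py` above re-exports the dispatcher types; this is the import path the rest of this commit switches to (see the `TYPE_CHECKING` import in `ep_moe/layer.py` above and the removal of the old `ep_moe.token_dispatcher` import in `deepseek_v2.py` below). A quick smoke test of the new path, assuming sglang from this branch is installed:

```python
# Assumes sglang from this branch; checks that the re-exported names resolve
# through the new package path used elsewhere in this commit.
from sglang.srt.layers.moe.token_dispatcher import (
    BaseDispatcher,
    DeepEPDispatcher,
    DispatchOutputFormat,
)

print(issubclass(DeepEPDispatcher, BaseDispatcher))   # expected: True
print([fmt.name for fmt in DispatchOutputFormat])     # e.g. ['standard', 'deepep_normal', ...]
```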
@@ -2,11 +2,22 @@ from __future__ import annotations
  from abc import ABC, abstractmethod
  from enum import Enum, auto
- from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+ from typing import Protocol, runtime_checkable
  import torch
+ class MoEA2ABackend(Enum):
+     none = "none"
+     deepep = "deepep"
+     def is_none(self):
+         return self == MoEA2ABackend.none
+     def is_deepep(self):
+         return self == MoEA2ABackend.deepep
  class DispatchOutputFormat(Enum):
      standard = auto()
      deepep_normal = auto()
......
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
  from __future__ import annotations
  import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
      DispatchOutput,
      DispatchOutputFormat,
  )
+ from sglang.srt.layers.moe.utils import DeepEPMode
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.utils import (
-     DeepEPMode,
-     get_bool_env_var,
-     get_int_env_var,
-     is_hip,
-     load_json_config,
- )
+ from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
  try:
      from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
              num_rdma_bytes,
          )
-         if deepep_mode == DeepEPMode.normal:
+         if deepep_mode == DeepEPMode.NORMAL:
              num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-         elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
+         elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
              num_qps_per_rank = num_experts // group.size()
          else:
              raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
              device="cuda"
          ).multi_processor_count
          if (
-             (deepep_mode != DeepEPMode.low_latency)
+             (deepep_mode != DeepEPMode.LOW_LATENCY)
              and not global_server_args_dict["enable_two_batch_overlap"]
              and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
          ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
          num_local_experts: int = None,
          hidden_size: int = None,
          params_dtype: torch.dtype = None,
-         deepep_mode: DeepEPMode = DeepEPMode.auto,
+         deepep_mode: DeepEPMode = DeepEPMode.AUTO,
          async_finish: bool = False,
          return_recv_hook: bool = False,
      ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
          resolved_deepep_mode = self.deepep_mode.resolve(
              forward_batch.is_extend_in_batch
          )
-         if resolved_deepep_mode == DeepEPMode.normal:
+         if resolved_deepep_mode == DeepEPMode.NORMAL:
              return self._normal_dispatcher
-         elif resolved_deepep_mode == DeepEPMode.low_latency:
+         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
              return self._low_latency_dispatcher
          else:
              raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
......
from enum import Enum
class MoeA2ABackend(Enum):
STANDARD = ("standard", "none")
DEEPEP = "deepep"
@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
for member in cls:
if value in member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD
class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"
def enable_normal(self):
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self):
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.AUTO:
return self
if is_extend_in_batch:
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
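A brief usage sketch for the enums above, matching the `--deepep-mode auto` behavior documented earlier (`normal` for prefill/extend batches, `low_latency` for decode) and the `_missing_` fallback of `MoeA2ABackend`; assumes sglang from this branch is installed:

```python
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD  # --moe-a2a-backend not set

auto = DeepEPMode("auto")
assert auto.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL        # prefill batch
assert auto.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY  # decode batch
assert DeepEPMode.AUTO.enable_normal() and DeepEPMode.AUTO.enable_low_latency()
```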
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
      ScheduleBatchDisaggregationDecodeMixin,
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.mem_cache.allocator import (
      BaseTokenToKVPoolAllocator,
      SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
      "enable_dp_attention",
      "enable_two_batch_overlap",
      "enable_dp_lm_head",
-     "enable_deepep_moe",
+     "moe_a2a_backend",
      "deepep_mode",
-     "enable_ep_moe",
      "enable_flashinfer_cutlass_moe",
      "enable_flashinfer_trtllm_moe",
      "enable_flashinfer_allreduce_fusion",
......
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
  )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.managers.io_struct import (
      AbortReq,
      CloseSessionReqInput,
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
  from sglang.srt.utils import (
-     DeepEPMode,
      DynamicGradMode,
      broadcast_pyobj,
      configure_gc_logger,
@@ -1762,8 +1762,10 @@ class Scheduler(
              spec_algorithm=self.spec_algorithm,
              speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
              enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-             enable_deepep_moe=self.server_args.enable_deepep_moe,
-             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+             enable_deepep_moe=MoeA2ABackend(
+                 self.server_args.moe_a2a_backend
+             ).is_deepep(),
+             deepep_mode=DeepEPMode(self.server_args.deepep_mode),
              require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
              disable_overlap_schedule=self.server_args.disable_overlap_schedule,
          )
......
@@ -38,6 +38,7 @@ import torch
  import triton
  import triton.language as tl
+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
  from sglang.srt.layers.dp_attention import (
      DPPaddingMode,
      get_attention_dp_rank,
@@ -839,7 +840,7 @@ class ForwardBatch:
  def enable_num_token_non_padded(server_args):
-     return server_args.enable_ep_moe or server_args.enable_deepep_moe
+     return get_moe_expert_parallel_world_size() > 1
  class PPProxyTensors:
......
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
      initialize_dp_attention,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.layers.quantization import (
      deep_gemm_wrapper,
      monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
              "use_mla_backend": self.use_mla_backend,
              "speculative_algorithm": self.spec_algorithm,
          }
+         | {
+             "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+             "deepep_mode": DeepEPMode(server_args.deepep_mode),
+         }
      )
      # CPU offload
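The `ModelRunner` hunk above is where the raw server arguments are converted into enum instances before being merged into `global_server_args_dict`; call sites throughout this commit then branch on enum methods instead of booleans. A small sketch of that pattern (the dict below is a stand-in, assuming sglang from this branch):

```python
# Stand-in for global_server_args_dict, showing the new enum-based checks
# (was: global_server_args_dict["enable_deepep_moe"] / ["enable_ep_moe"]).
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

server_args_view = {
    "moe_a2a_backend": MoeA2ABackend("deepep"),
    "deepep_mode": DeepEPMode("auto"),
}

if server_args_view["moe_a2a_backend"].is_deepep():
    # e.g. choose the DeepEP dispatcher / DeepEPMoE implementation
    mode = server_args_view["deepep_mode"]
```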
......
@@ -29,6 +29,7 @@ from tqdm import tqdm
  from transformers import PretrainedConfig
  from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
      get_tensor_model_parallel_world_size,
      parallel_state,
      tensor_model_parallel_all_reduce,
@@ -61,7 +62,6 @@ from sglang.srt.layers.moe.ep_moe.layer import (
      get_moe_impl_class,
      should_use_flashinfer_trtllm_moe,
  )
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
  )
  from sglang.srt.utils import (
      BumpAllocator,
-     DeepEPMode,
      LazyValue,
      add_prefix,
      bind_or_assign,
@@ -333,15 +332,14 @@ class DeepseekV2MoE(nn.Module):
              routed_scaling_factor=self.routed_scaling_factor,
              prefix=add_prefix("experts", prefix),
              **(
-                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                 if global_server_args_dict["enable_deepep_moe"]
+                 dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
              # Additional args for FusedMoE
              **(
                  dict(
                      enable_flashinfer_cutlass_moe=True,
-                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                  )
                  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                  else {}
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
              prefix=add_prefix("shared_experts", prefix),
              **(
                  dict(tp_rank=0, tp_size=1)
-                 if global_server_args_dict["enable_deepep_moe"]
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
          )
@@ -404,9 +402,9 @@ class DeepseekV2MoE(nn.Module):
          self.top_k = config.num_experts_per_tok
-         if global_server_args_dict["enable_deepep_moe"]:
+         if global_server_args_dict["moe_a2a_backend"].is_deepep():
              # TODO: we will support tp < ep in the future
-             self.ep_size = get_tensor_model_parallel_world_size()
+             self.ep_size = get_moe_expert_parallel_world_size()
              self.num_experts = (
                  config.n_routed_experts
                  + global_server_args_dict["ep_num_redundant_experts"]
@@ -428,12 +426,12 @@ class DeepseekV2MoE(nn.Module):
                  num_local_experts=config.n_routed_experts // self.tp_size,
                  hidden_size=config.hidden_size,
                  params_dtype=config.torch_dtype,
-                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                 deepep_mode=global_server_args_dict["deepep_mode"],
                  async_finish=True,
                  return_recv_hook=True,
              )
-         self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
      def get_moe_weights(self):
          return [
@@ -2104,11 +2102,8 @@ class DeepseekV2ForCausalLM(nn.Module):
              or self.config.n_shared_experts != 1
          ):
              disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-         elif (
-             global_server_args_dict["enable_deepep_moe"]
-             or global_server_args_dict["enable_ep_moe"]
-         ):
-             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+         elif get_moe_expert_parallel_world_size() > 1:
+             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
          if disable_reason is not None:
              global_server_args_dict["disable_shared_experts_fusion"] = True
......
@@ -23,6 +23,7 @@ from torch import nn
  from transformers import PretrainedConfig
  from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      parallel_state,
@@ -50,7 +51,6 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import (
-     DeepEPMoE,
      get_moe_impl_class,
      should_use_flashinfer_trtllm_moe,
  )
@@ -83,7 +83,6 @@ from sglang.srt.two_batch_overlap import (
  )
  from sglang.srt.utils import (
      BumpAllocator,
-     DeepEPMode,
      LazyValue,
      add_prefix,
      bind_or_assign,
@@ -443,15 +442,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
              routed_scaling_factor=self.routed_scaling_factor,
              prefix=add_prefix("experts", prefix),
              **(
-                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                 if global_server_args_dict["enable_deepep_moe"]
+                 dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
              # Additional args for FusedMoE
              **(
                  dict(
                      enable_flashinfer_cutlass_moe=True,
-                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                  )
                  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                  else {}
@@ -484,7 +482,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
              prefix=add_prefix("shared_experts", prefix),
              **(
                  dict(tp_rank=0, tp_size=1)
-                 if global_server_args_dict["enable_deepep_moe"]
+                 if global_server_args_dict["moe_a2a_backend"].is_deepep()
                  else {}
              ),
          )
@@ -502,9 +500,9 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
          self.top_k = config.num_experts_per_tok
-         if global_server_args_dict["enable_deepep_moe"]:
+         if global_server_args_dict["moe_a2a_backend"].is_deepep():
              # TODO: we will support tp < ep in the future
-             self.ep_size = get_tensor_model_parallel_world_size()
+             self.ep_size = get_moe_expert_parallel_world_size()
              self.num_experts = (
                  config.n_routed_experts
                  + global_server_args_dict["ep_num_redundant_experts"]
@@ -526,12 +524,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                  num_local_experts=config.n_routed_experts // self.tp_size,
                  hidden_size=config.hidden_size,
                  params_dtype=config.torch_dtype,
-                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                 deepep_mode=global_server_args_dict["deepep_mode"],
                  async_finish=True,
                  return_recv_hook=True,
              )
-         self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
  class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
@@ -737,11 +735,8 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
              or self.config.n_shared_experts != 1
          ):
              disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-         elif (
-             global_server_args_dict["enable_deepep_moe"]
-             or global_server_args_dict["enable_ep_moe"]
-         ):
-             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+         elif get_moe_expert_parallel_world_size() > 1:
+             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
          if disable_reason is not None:
              global_server_args_dict["disable_shared_experts_fusion"] = True
......