Unverified Commit 8e10fec9, authored by fzyzcjy, committed by GitHub

Small refactor DeepEPMode to clean up code a bit (#4992)

parent e8999b13
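In short, this change replaces the string-valued `deepep_mode` fields, and the repeated `== "normal"` / `== "auto" and ... is_decode()` comparisons at every call site, with a small `DeepEPMode` enum added to `sglang.srt.utils` at the end of this diff. A minimal before/after sketch of the call-site pattern, using a hypothetical `_DecodeStub` stand-in for `ForwardMode` (illustration only, not part of the diff):

from enum import Enum

class DeepEPMode(Enum):  # mirrors the enum added to sglang.srt.utils below
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def resolve(self, forward_mode):
        # "auto" defers to the batch type: decode -> low_latency, else normal.
        if self != DeepEPMode.auto:
            return self
        return DeepEPMode.low_latency if forward_mode.is_decode() else DeepEPMode.normal

class _DecodeStub:  # hypothetical ForwardMode stand-in; only is_decode() is needed
    def is_decode(self):
        return True

# Before (repeated at each call site):
#   if self.deepep_mode == "normal" or (
#       self.deepep_mode == "auto" and not forward_mode.is_decode()
#   ): ...
# After: resolve once, then compare enum members.
assert DeepEPMode.auto.resolve(_DecodeStub()) == DeepEPMode.low_latency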
@@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs

 _is_cuda = is_cuda()
@@ -47,7 +47,6 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as vllm_ops
-
 logger = logging.getLogger(__name__)
 _is_hip = is_hip()
@@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
-        deepep_mode: str = "auto",
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
             num_experts,
@@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE):
             activation,
         )
         self.deepep_mode = deepep_mode
-        if self.deepep_mode in ["low_latency", "auto"]:
+        if self.deepep_mode.enable_low_latency():
             assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
             self.w13_weight_fp8 = (
                 self.w13_weight,
@@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE):
         expected_m: int,
         forward_mode: ForwardMode,
     ):
-        if self.deepep_mode == "normal" or (
-            self.deepep_mode == "auto" and not forward_mode.is_decode()
-        ):
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif self.deepep_mode == "low_latency" or (
-            self.deepep_mode == "auto" and forward_mode.is_decode()
-        ):
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
             return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
...
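One behavioral detail worth noting in the constructor hunk above: `enable_low_latency()` is true for both `low_latency` and `auto`, so the `use_deep_gemm` assert fires for `auto` exactly as the old `in ["low_latency", "auto"]` check did. A tiny self-contained sketch (enum body copied from the bottom of this diff):

from enum import Enum

class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_low_latency(self):
        return self in [DeepEPMode.low_latency, DeepEPMode.auto]

for mode in DeepEPMode:
    print(mode.name, "triggers the deep_gemm assert:", mode.enable_low_latency())
# normal triggers the deep_gemm assert: False
# low_latency triggers the deep_gemm assert: True
# auto triggers the deep_gemm assert: True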
+from sglang.srt.utils import DeepEPMode

 try:
     from deep_ep import Buffer
@@ -98,7 +100,7 @@ class DeepEPDispatcher:
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: str = "auto",
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -120,13 +122,13 @@ class DeepEPDispatcher:
         self.deepep_mode = deepep_mode

         self.handle = None
-        if self.deepep_mode in ["normal", "auto"]:  # for normal / auto mode
+        if self.deepep_mode.enable_normal():
             self.buffer_normal = get_buffer_normal(
                 self.group, self.hidden_size * self.params_bytes
             )
             self.async_finish = async_finish
             self.src2dst = None
-        if self.deepep_mode in ["low_latency", "auto"]:  # for low_latency / auto mode
+        if self.deepep_mode.enable_low_latency():
             """
             num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
             https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -196,9 +198,8 @@ class DeepEPDispatcher:
         )
         expected_m = 0

-        if self.deepep_mode == "normal" or (
-            self.deepep_mode == "auto" and not forward_mode.is_decode()
-        ):
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
             (
                 hidden_states,
                 topk_idx,
@@ -210,9 +211,7 @@ class DeepEPDispatcher:
             reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
                 hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
             )
-        elif self.deepep_mode == "low_latency" or (
-            self.deepep_mode == "auto" and forward_mode.is_decode()
-        ):
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
             expected_m = (
                 hidden_states.shape[0]
                 * self.buffer_low_latency.group_size
@@ -354,9 +353,8 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
     ) -> torch.Tensor:
-        if self.deepep_mode == "normal" or (
-            self.deepep_mode == "auto" and not forward_mode.is_decode()
-        ):
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
             if hidden_states.shape[0] > 0:
                 num_tokens = self.src2dst.shape[0] // self.router_topk
                 output = torch.empty(
@@ -384,9 +382,7 @@ class DeepEPDispatcher:
                     output,
                 )
                 event.current_stream_wait() if self.async_finish else ()
-        elif self.deepep_mode == "low_latency" or (
-            self.deepep_mode == "auto" and forward_mode.is_decode()
-        ):
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
             hidden_states, event, hook = self.combine_low_latency(
                 hidden_states,
                 topk_idx,
...
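The two `if` blocks in the dispatcher constructor above are intentionally not exclusive: under `auto`, both `enable_normal()` and `enable_low_latency()` are true, so both buffers get allocated, matching the old `in ["normal", "auto"]` / `in ["low_latency", "auto"]` pair. A minimal sketch with a hypothetical `allocated_buffers` helper (illustration only):

from enum import Enum

class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        return self in [DeepEPMode.normal, DeepEPMode.auto]

    def enable_low_latency(self):
        return self in [DeepEPMode.low_latency, DeepEPMode.auto]

def allocated_buffers(mode: DeepEPMode) -> list:
    # Mirrors the constructor logic: each flag independently gates one buffer.
    buffers = []
    if mode.enable_normal():
        buffers.append("buffer_normal")
    if mode.enable_low_latency():
        buffers.append("buffer_low_latency")
    return buffers

assert allocated_buffers(DeepEPMode.auto) == ["buffer_normal", "buffer_low_latency"]
assert allocated_buffers(DeepEPMode.normal) == ["buffer_normal"]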
@@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module):
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             prefix=add_prefix("experts", prefix),
-            deepep_mode=global_server_args_dict["deepep_mode"],
+            deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
         )

         if config.n_shared_experts is not None:
@@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                 async_finish=True,  # TODO
                 return_recv_hook=True,
             )
...
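Since `global_server_args_dict["deepep_mode"]` still holds the raw CLI string, the model code converts it with `Enum`'s name lookup, `DeepEPMode[...]`. This works because each member's name equals its value, and an unknown string now raises `KeyError` at model construction rather than surfacing later at dispatch time. A quick self-contained sketch:

from enum import Enum

class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

assert DeepEPMode["auto"] is DeepEPMode.auto   # name lookup, as used above
assert DeepEPMode("auto") is DeepEPMode.auto   # value lookup also works here

try:
    DeepEPMode["low-latency"]  # hypothetical typo: rejected at construction time
except KeyError:
    print("invalid deepep_mode rejected early")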
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Optional
+from typing import List, Literal, Optional

 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -161,7 +161,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    deepep_mode: Optional[str] = "auto"
+    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
...
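`Literal[...]` gives static type checkers (and anything that introspects the annotation) the closed set of valid strings; it does not validate anything at runtime by itself. If runtime validation were wanted, one could check against `get_args` of the annotation, e.g. (a hypothetical helper, not part of this diff):

from typing import Literal, Optional, get_args

DeepEPModeStr = Literal["auto", "normal", "low_latency"]

def validate_deepep_mode(value: Optional[str]) -> None:
    # Hypothetical sketch: reject strings outside the Literal's closed set.
    if value is not None and value not in get_args(DeepEPModeStr):
        raise ValueError(
            f"Invalid deepep_mode: {value!r}, expected one of {get_args(DeepEPModeStr)}"
        )

validate_deepep_mode("auto")       # ok
# validate_deepep_mode("fastest")  # would raise ValueError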
@@ -37,6 +37,7 @@ import time
 import traceback
 import warnings
 from contextlib import contextmanager
+from enum import Enum
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list):
         ]
     else:
         return [nested_list]
+
+
+class DeepEPMode(Enum):
+    normal = "normal"
+    low_latency = "low_latency"
+    auto = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.normal, DeepEPMode.auto]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.low_latency, DeepEPMode.auto]
+
+    def resolve(self, forward_mode):
+        if self != DeepEPMode.auto:
+            return self
+
+        if forward_mode.is_decode():
+            return DeepEPMode.low_latency
+        else:
+            return DeepEPMode.normal
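End-to-end behavior of the new class, assuming the `DeepEPMode` just added above is in scope and using a hypothetical stand-in for `ForwardMode` (only `is_decode()` is needed by `resolve()`): explicit modes pass through unchanged, and `auto` splits on decode vs. non-decode batches.

class _StubForwardMode:
    # Hypothetical ForwardMode stand-in; resolve() only calls is_decode().
    def __init__(self, decode: bool):
        self._decode = decode

    def is_decode(self) -> bool:
        return self._decode

decode, prefill = _StubForwardMode(True), _StubForwardMode(False)

assert DeepEPMode.normal.resolve(decode) is DeepEPMode.normal            # explicit: unchanged
assert DeepEPMode.low_latency.resolve(prefill) is DeepEPMode.low_latency
assert DeepEPMode.auto.resolve(decode) is DeepEPMode.low_latency         # auto + decode
assert DeepEPMode.auto.resolve(prefill) is DeepEPMode.normal             # auto + prefill/extend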