"vscode:/vscode.git/clone" did not exist on "5366db5df1fb16bd92491ec2b624d619b21b62a5"
Unverified Commit 8fc910db authored by Cheng Wan, committed by GitHub

DP Attention with Auto DeepEP Dispatch (#7222)

parent 75354d9a
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = last_batch_in_queue

     def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
-        batch, _ = self.prepare_mlp_sync_batch(batch)
+        batch = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
......
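The hunk above (and the matching call sites in the prefill mixin and scheduler below) drops the second return value from `prepare_mlp_sync_batch`: the DP-gathered extend flag now travels on the batch itself as `is_extend_in_batch`. A minimal runnable sketch of the new convention, using a stand-in dataclass rather than the real `ScheduleBatch`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BatchStub:
    # Stand-in for ScheduleBatch; only the field this PR threads through.
    is_extend_in_batch: bool = False


def prepare_mlp_sync_batch(batch: Optional[BatchStub]) -> Optional[BatchStub]:
    # Sketch: record the gathered "any DP rank is extending" flag on the
    # batch instead of returning it as a second tuple element.
    if batch is not None:
        batch.is_extend_in_batch = True  # stands in for any(is_extend_in_batch)
    return batch


batch = prepare_mlp_sync_batch(BatchStub())
if batch:
    assert batch.is_extend_in_batch
```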
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
......
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
     ceil_div,
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
......
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode = None,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state

     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)

-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
......
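`DeepEPDispatcher` keeps its split-phase protocol; only the value stashed between the A and B halves changes from a `ForwardMode` to the whole `ForwardBatch`. A toy reconstruction of that staging pattern (not the real class; `_get_impl`, stage checks, and the async all-to-all are elided):

```python
class TwoPhaseDispatcher:
    """Toy model of the dispatch_a/dispatch_b staging used above."""

    def dispatch_a(self, hidden_states, forward_batch):
        # Phase A would launch the async all-to-all; here we just stash
        # (forward_batch, inner_state), exactly like the real dispatcher.
        inner_state = (hidden_states,)
        self._dispatch_intermediate_state = forward_batch, inner_state

    def dispatch_b(self):
        # Phase B pops the stash, so each A pairs with exactly one B;
        # deleting the attribute makes a double dispatch_b fail loudly.
        forward_batch, inner_state = self._dispatch_intermediate_state
        del self._dispatch_intermediate_state
        return inner_state[0]


d = TwoPhaseDispatcher()
d.dispatch_a(hidden_states=[0.5, 1.5], forward_batch=object())
assert d.dispatch_b() == [0.5, 1.5]
```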
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
-    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            is_extend_in_batch=self.is_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
......
@@ -1490,7 +1490,7 @@ class Scheduler(
         if need_dp_attn_preparation and not self.spec_algorithm.is_none():
             # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
             # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            new_batch = self.prepare_mlp_sync_batch(new_batch)
             need_dp_attn_preparation = new_batch is None

         if new_batch is not None:
@@ -1506,7 +1506,7 @@ class Scheduler(
         # Handle DP attention
         if need_dp_attn_preparation:
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+            ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1923,8 +1923,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-        return local_batch, any(is_extend_in_batch)
+        return local_batch

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
......
@@ -254,6 +254,7 @@ class ForwardBatch:
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -299,6 +300,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
+            is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
......
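Together with the `ScheduleBatch` and `ModelWorkerBatch` hunks earlier, the dataclass changes above thread one boolean through the whole batch pipeline, so MoE layers can resolve the DeepEP path without a `ForwardMode`. A compressed sketch of that flow, with stub types in place of the real (much larger) classes:

```python
from dataclasses import dataclass


# Stand-ins for ScheduleBatch -> ModelWorkerBatch -> ForwardBatch; the PR
# copies is_extend_in_batch along exactly this chain.
@dataclass
class ScheduleBatchStub:
    is_extend_in_batch: bool = False


@dataclass
class ModelWorkerBatchStub:
    is_extend_in_batch: bool


@dataclass
class ForwardBatchStub:
    is_extend_in_batch: bool = False


sb = ScheduleBatchStub(is_extend_in_batch=True)   # set during DP sync
mwb = ModelWorkerBatchStub(is_extend_in_batch=sb.is_extend_in_batch)
fb = ForwardBatchStub(is_extend_in_batch=mwb.is_extend_in_batch)
assert fb.is_extend_in_batch  # visible to DeepEPMoE / DeepEPDispatcher
```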
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         if shared_output is not None:
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.hidden_states_mlp_input,
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and hidden_states.shape[0] == 0
         ):
             state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch.forward_mode
+                hidden_states, state.forward_batch
             )
         else:
             state.hidden_states_mlp_output = hidden_states
......
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         return final_hidden_states
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
......
@@ -418,10 +418,6 @@ class ServerArgs:
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
......
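Removing this assertion is the user-visible core of the PR: `--deepep-mode auto` may now be combined with DP attention. A hypothetical launch argument list in the style of the updated tests below (flag spellings come from the diff and the sglang CLI; the parallel sizes are placeholders):

```python
# Hypothetical args; mirrors the flag combinations the updated tests use.
other_args = [
    "--tp", "8",
    "--enable-dp-attention",
    "--dp", "8",
    "--enable-deepep-moe",
    "--deepep-mode", "auto",       # previously rejected with DP attention
    "--cuda-graph-max-bs", "128",  # CUDA graphs stay enabled in auto mode
]
```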
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:
 class TboDPAttentionPreparer:
     def prepare_all_gather(
-        self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
+        self,
+        local_batch: ScheduleBatch,
+        deepep_mode: DeepEPMode,
+        enable_deepep_moe: bool,
+        enable_two_batch_overlap: bool,
     ):
         self.enable_two_batch_overlap = enable_two_batch_overlap
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
                 extend_lens=local_batch.extend_lens,
                 token_num_per_seq=token_num_per_seq,
             )
-            resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
+            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
             local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
                 (
                     local_batch.forward_mode.is_extend()
......
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
     def enable_low_latency(self):
         return self in [DeepEPMode.low_latency, DeepEPMode.auto]

-    def resolve(self, forward_mode):
+    def resolve(self, is_extend_in_batch: bool):
         if self != DeepEPMode.auto:
             return self

-        if forward_mode.is_decode():
-            return DeepEPMode.low_latency
-        else:
+        if is_extend_in_batch:
             return DeepEPMode.normal
+        else:
+            return DeepEPMode.low_latency


 def is_non_idle_and_non_empty(forward_mode, hidden_states):
......
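The rewritten `resolve` is the decision point `auto` mode hangs on: a DP group whose gathered batch contains any extend/prefill tokens takes the contiguous `normal` path, while a pure decode batch takes `low_latency`. A self-contained restatement of exactly that logic:

```python
from enum import Enum


class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def resolve(self, is_extend_in_batch: bool) -> "DeepEPMode":
        # Same logic as the hunk above: non-auto modes are fixed; auto
        # picks per batch based on the DP-gathered extend flag.
        if self != DeepEPMode.auto:
            return self
        if is_extend_in_batch:
            return DeepEPMode.normal
        else:
            return DeepEPMode.low_latency


assert DeepEPMode.auto.resolve(True) is DeepEPMode.normal
assert DeepEPMode.auto.resolve(False) is DeepEPMode.low_latency
assert DeepEPMode.normal.resolve(False) is DeepEPMode.normal
```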
@@ -539,8 +539,9 @@ class Test10(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -593,8 +594,9 @@ class Test11(CustomTestCase):
                 "4",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -647,8 +649,9 @@ class Test12(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -700,8 +703,9 @@ class Test13(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -756,8 +760,9 @@ class Test14(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -812,8 +817,9 @@ class Test15(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -867,8 +873,9 @@ class Test16(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -922,8 +929,9 @@ class Test17(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -979,8 +987,9 @@ class Test18(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -1036,8 +1045,9 @@ class Test19(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -2213,8 +2223,11 @@ class Test40(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2277,8 +2290,11 @@ class Test41(CustomTestCase):
                 "4",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2341,8 +2357,11 @@ class Test42(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2404,8 +2423,11 @@ class Test43(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2470,8 +2492,11 @@ class Test44(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2536,8 +2561,11 @@ class Test45(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2601,8 +2629,11 @@ class Test46(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2666,8 +2697,11 @@ class Test47(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2733,8 +2767,11 @@ class Test48(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2800,8 +2837,11 @@ class Test49(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
......
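Every test hunk above applies the same substitution: `--deepep-mode normal --disable-cuda-graph` becomes `--deepep-mode auto` with CUDA graphs left on and capped. A hypothetical helper summarizing that pattern (`deepep_flags` does not appear in the diff):

```python
def deepep_flags(speculative: bool) -> list:
    # Hypothetical summary of the updated test flags above.
    flags = ["--enable-deepep-moe", "--deepep-mode", "auto"]
    if speculative:
        # Speculative-decoding tests also cap in-flight requests to the
        # captured graph batch size.
        flags += ["--cuda-graph-max-bs", "32", "--max-running-requests", "32"]
    else:
        flags += ["--cuda-graph-max-bs", "128"]
    return flags


assert "--disable-cuda-graph" not in deepep_flags(speculative=True)
```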