Unverified commit 8fc910db authored by Cheng Wan, committed by GitHub

DP Attention with Auto DeepEP Dispatch (#7222)

parent 75354d9a
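
Summary: `--deepep-mode auto` previously could not be combined with DP attention because the dispatch mode was resolved from the local rank's `ForwardMode`, and ranks in the same DP attention group may be in different forward modes. This change threads the full `ForwardBatch` through `DeepEPMoE` and `DeepEPDispatcher`, carries a new `is_extend_in_batch` flag from `ScheduleBatch` through `ModelWorkerBatch` into `ForwardBatch`, and makes `DeepEPMode.resolve()` key off that flag: if any request in the group is extending, the group uses normal dispatch; a pure decode group uses low-latency dispatch. The `ServerArgs` assertion that rejected `auto` with DP attention is removed, `prepare_mlp_sync_batch` now returns just the batch, and the DeepEP CI tests switch from `normal` mode with CUDA graphs disabled to `auto` mode with CUDA graphs enabled.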
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = last_batch_in_queue

     def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
-        batch, _ = self.prepare_mlp_sync_batch(batch)
+        batch = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
...
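`prepare_mlp_sync_batch` used to return `(batch, any(is_extend_in_batch))`; the boolean now travels on the batch itself (see the `schedule_batch.py` and `scheduler.py` hunks below), so this and every other call site drops the second return value.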
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch

         if batch:
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
     ceil_div,
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
...
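`DeepEPMoE.forward` now takes the whole `ForwardBatch` rather than a `ForwardMode`, so `auto` mode can be resolved from `forward_batch.is_extend_in_batch`, a batch-level property that is consistent across the DP attention group. A runnable sketch of the resolution rule follows the `utils.py` hunk further below.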
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode = None,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state

     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)

-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
...
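The dispatcher's split-phase API stashes `(forward_batch, inner_state)` between the `*_a` and `*_b` halves so that `_get_impl` selects the same implementation (normal vs. low-latency) in both halves. Below is a self-contained toy of that handoff pattern; all names are illustrative, not sglang's real signatures.

```python
class TwoPhase:
    """Toy version of the dispatch_a/dispatch_b handoff shown above."""

    def dispatch_a(self, payload, is_extend_in_batch: bool):
        # Resolve the implementation once and stash it with the state,
        # mirroring self._dispatch_intermediate_state in the diff.
        impl = "normal" if is_extend_in_batch else "low_latency"
        self._stash = (impl, payload)

    def dispatch_b(self):
        impl, payload = self._stash
        del self._stash  # the real code also deletes its stash after use
        return impl, payload


tp = TwoPhase()
tp.dispatch_a("tokens", is_extend_in_batch=False)
assert tp.dispatch_b() == ("low_latency", "tokens")
```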
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
-    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            is_extend_in_batch=self.is_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
...
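The flag is a plain field on each batch class and is copied at every conversion; the `forward_batch_info.py` hunk below completes the chain. Schematically (field-only stand-ins; the real classes carry dozens of fields and use `init_new`-style constructors):

```python
from dataclasses import dataclass


@dataclass
class ScheduleBatch:  # scheduler-side batch
    is_extend_in_batch: bool = False


@dataclass
class ModelWorkerBatch:  # handed to the model worker
    is_extend_in_batch: bool = False


@dataclass
class ForwardBatch:  # what the model layers actually see
    is_extend_in_batch: bool = False


# Copied at each hop, as in the hunks above and below:
sb = ScheduleBatch(is_extend_in_batch=True)
mwb = ModelWorkerBatch(is_extend_in_batch=sb.is_extend_in_batch)
fb = ForwardBatch(is_extend_in_batch=mwb.is_extend_in_batch)
assert fb.is_extend_in_batch
```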
@@ -1490,7 +1490,7 @@ class Scheduler(
         if need_dp_attn_preparation and not self.spec_algorithm.is_none():
             # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
             # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            new_batch = self.prepare_mlp_sync_batch(new_batch)
             need_dp_attn_preparation = new_batch is None

         if new_batch is not None:
@@ -1506,7 +1506,7 @@ class Scheduler(
         # Handle DP attention
         if need_dp_attn_preparation:
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+            ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1923,8 +1923,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-        return local_batch, any(is_extend_in_batch)
+        return local_batch

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
...
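The group-wide `any(is_extend_in_batch)` is now presumably assigned to `local_batch.is_extend_in_batch` inside `prepare_mlp_sync_batch`, which is exactly what the removed TODO asked for; the tuple return disappears with it.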
@@ -254,6 +254,7 @@ class ForwardBatch:
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -299,6 +300,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
+            is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
...
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         if shared_output is not None:
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.hidden_states_mlp_input,
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and hidden_states.shape[0] == 0
         ):
             state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch.forward_mode
+                hidden_states, state.forward_batch
             )
         else:
             state.hidden_states_mlp_output = hidden_states
...
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         return final_hidden_states
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
...
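`Qwen3MoeSparseMoeBlock` mirrors the `DeepseekV2MoE` changes above one-for-one: dispatch, combine, and the overlapped `op_*` operations pass `state.forward_batch` instead of `state.forward_batch.forward_mode`, and `self.mlp` now receives the batch directly.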
@@ -418,10 +418,6 @@ class ServerArgs:
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
...
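Dropping this assertion is what actually unlocks the combination named in the title: resolution no longer consults the per-rank `ForwardMode`, so `auto` is safe under DP attention. Note that `normal` mode still forces CUDA graphs off, which is why the tests below migrate to `auto`.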
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:

 class TboDPAttentionPreparer:
     def prepare_all_gather(
-        self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
+        self,
+        local_batch: ScheduleBatch,
+        deepep_mode: DeepEPMode,
+        enable_deepep_moe: bool,
+        enable_two_batch_overlap: bool,
     ):
         self.enable_two_batch_overlap = enable_two_batch_overlap
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
                 extend_lens=local_batch.extend_lens,
                 token_num_per_seq=token_num_per_seq,
             )
-            resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
+            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
             local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
                 (
                     local_batch.forward_mode.is_extend()
...
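`TboDPAttentionPreparer.prepare_all_gather` runs on the scheduler's `ScheduleBatch` before any `ForwardBatch` exists, which is presumably why the flag has to live on `ScheduleBatch` as well; the signature also gains type annotations, hence the new `ScheduleBatch` import above.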
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
     def enable_low_latency(self):
         return self in [DeepEPMode.low_latency, DeepEPMode.auto]

-    def resolve(self, forward_mode):
+    def resolve(self, is_extend_in_batch: bool):
         if self != DeepEPMode.auto:
             return self

-        if forward_mode.is_decode():
-            return DeepEPMode.low_latency
-        else:
+        if is_extend_in_batch:
             return DeepEPMode.normal
+        else:
+            return DeepEPMode.low_latency


 def is_non_idle_and_non_empty(forward_mode, hidden_states):
...
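This is the heart of auto mode: explicit modes pass through unchanged, while `auto` picks normal dispatch whenever the group-level batch contains an extend, and low-latency dispatch when it is pure decode. A self-contained sketch for illustration; the real class lives in `sglang.srt.utils` with more members, and the string values here are an assumption.

```python
from enum import Enum


class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def resolve(self, is_extend_in_batch: bool) -> "DeepEPMode":
        if self != DeepEPMode.auto:
            return self  # explicit modes pass through unchanged
        if is_extend_in_batch:
            return DeepEPMode.normal  # extend/prefill in the group: normal dispatch
        else:
            return DeepEPMode.low_latency  # pure decode group: low-latency dispatch


assert DeepEPMode.auto.resolve(True) is DeepEPMode.normal
assert DeepEPMode.auto.resolve(False) is DeepEPMode.low_latency
assert DeepEPMode.normal.resolve(False) is DeepEPMode.normal
```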
@@ -539,8 +539,9 @@ class Test10(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -593,8 +594,9 @@ class Test11(CustomTestCase):
                 "4",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -647,8 +649,9 @@ class Test12(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -700,8 +703,9 @@ class Test13(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -756,8 +760,9 @@ class Test14(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -812,8 +817,9 @@ class Test15(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -867,8 +873,9 @@ class Test16(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -922,8 +929,9 @@ class Test17(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -979,8 +987,9 @@ class Test18(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -1036,8 +1045,9 @@ class Test19(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "128",
             ],
         )
@@ -2213,8 +2223,11 @@ class Test40(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2277,8 +2290,11 @@ class Test41(CustomTestCase):
                 "4",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2341,8 +2357,11 @@ class Test42(CustomTestCase):
                 "8",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2404,8 +2423,11 @@ class Test43(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2470,8 +2492,11 @@ class Test44(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2536,8 +2561,11 @@ class Test45(CustomTestCase):
                 "1",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2601,8 +2629,11 @@ class Test46(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2666,8 +2697,11 @@ class Test47(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2733,8 +2767,11 @@ class Test48(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
@@ -2800,8 +2837,11 @@ class Test49(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--deepep-mode",
-                "normal",
-                "--disable-cuda-graph",
+                "auto",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
                 "--speculative-algo",
                 "NEXTN",
                 "--speculative-draft",
...
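All twenty DeepEP test configurations flip from `normal` mode with `--disable-cuda-graph` to `auto` mode with a CUDA graph budget, and the speculative-decoding variants (Test40 through Test49) additionally cap `--max-running-requests` at 32. The shared pattern, with the surrounding `popen_launch_server`-style harness and neighboring arguments elided:

```python
# Argument pattern now shared by the updated tests (exact neighbors vary).
other_args = [
    "--enable-deepep-moe",
    "--deepep-mode", "auto",
    "--cuda-graph-max-bs", "128",  # Test40-49 use "32" and add "--max-running-requests", "32"
]
```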