Unverified Commit f9bab3d5 authored by Yi Zhang, committed by GitHub

qwen3moe support two batch overlap (#6598)

parent 16f69b1f
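
In two-batch overlap (TBO), a forward batch is split into two sub-batches, each decoder layer is decomposed into small op_* stages (communication prep, attention prepare/core, MoE gate, expert selection, DeepEP dispatch/combine, output), and an OperationsStrategy lists those stages with YieldOperation() markers that switch execution to the other sub-batch, so one sub-batch's expert-parallel communication can overlap the other's compute. The snippet below is only a minimal, self-contained sketch of that interleaving idea; every name in it is invented for illustration and it is not the SGLang implementation.

from typing import Callable, Dict, List, Union


class Yield:
    """Marker op: pause this sub-batch and hand control to the other one."""


Op = Union[Callable[[Dict], None], Yield]


def run_two_batch_overlap(operations: List[Op], states: List[Dict]) -> None:
    """Round-robin two sub-batch states over the same operation list.

    Each sub-batch advances until it hits a Yield marker; while it is parked
    there (e.g. waiting on dispatched communication), the other sub-batch runs
    its compute stages.
    """

    def runner(state: Dict):
        for op in operations:
            if isinstance(op, Yield):
                yield  # switch to the other sub-batch
            else:
                op(state)

    runners = [runner(state) for state in states]
    done = [False] * len(runners)
    while not all(done):
        for i, r in enumerate(runners):
            if not done[i]:
                try:
                    next(r)
                except StopIteration:
                    done[i] = True


# Toy usage: two sub-batches walk [attn, <yield>, mlp]; the yield is where a
# real implementation would let one sub-batch's communication overlap the
# other's compute.
ops: List[Op] = [
    lambda s: s.setdefault("trace", []).append("attn"),
    Yield(),
    lambda s: s["trace"].append("mlp"),
]
states: List[Dict] = [{}, {}]
run_two_batch_overlap(ops, states)
print([s["trace"] for s in states])  # [['attn', 'mlp'], ['attn', 'mlp']]
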
@@ -68,6 +68,7 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import add_prefix, make_layers

logger = logging.getLogger(__name__)
@@ -442,12 +443,22 @@ class Qwen2MoeModel(nn.Module):
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            with get_global_expert_distribution_recorder().with_current_layer(i):
-                layer = self.layers[i]
-                hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual
-                )
+        if forward_batch.can_run_tbo:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers,
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        else:
+            for i in range(self.start_layer, self.end_layer):
+                with get_global_expert_distribution_recorder().with_current_layer(i):
+                    layer = self.layers[i]
+                    hidden_states, residual = layer(
+                        positions, hidden_states, forward_batch, residual
+                    )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
......
@@ -68,6 +68,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -79,6 +82,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

Qwen3MoeConfig = None
@@ -137,7 +141,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
        self.top_k = config.num_experts_per_tok
        self.renormalize = config.norm_topk_prob

-        self.deepep_dispatcher = DeepEPDispatcher(
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
            group=parallel_state.get_tp_group().device_group,
            router_topk=self.top_k,
            permute_fusion=True,
@@ -217,9 +221,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states = self.experts(
@@ -235,13 +239,105 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
            )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
            )
        return final_hidden_states

def op_gate(self, state):
if is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
):
# router_logits: (num_tokens, n_experts)
state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
else:
state.router_logits = None
def op_select_experts(self, state):
router_logits = state.pop("router_logits")
hidden_states = state.hidden_states_mlp_input
if router_logits is not None:
state.topk_weights_local, state.topk_idx_local = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
def op_dispatch_a(self, state):
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_dispatch_b(self, state):
if self.ep_size > 1:
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_combine_b(self, state):
if self.ep_size > 1:
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_output(self, state):
state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")

class Qwen3MoeAttention(nn.Module):
    def __init__(

@@ -339,20 +435,54 @@ class Qwen3MoeAttention(nn.Module):
        k = k_by_head.view(k.shape)
        return q, k

-    def forward(
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state)
        output, _ = self.o_proj(attn_output)
        return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
return self.forward_core(s)

class Qwen3MoeDecoderLayer(nn.Module):
    def __init__(

@@ -462,6 +592,65 @@ class Qwen3MoeDecoderLayer(nn.Module):
        return hidden_states, residual

def op_comm_prepare_attn(
self,
state,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
tbo_subbatch_index: Optional[int] = None,
):
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
)
state.update(
dict(
forward_batch=forward_batch,
positions=positions,
tbo_subbatch_index=tbo_subbatch_index,
)
)
def op_comm_prepare_mlp(self, state):
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
self.layer_communicator.prepare_mlp(
state.pop("hidden_states_after_attn"),
state.pop("residual_after_input_ln"),
state.forward_batch,
)
)
def op_mlp(self, state):
hidden_states = state.pop("hidden_states_mlp_input")
state.hidden_states_mlp_output = self.mlp(
hidden_states, state.forward_batch.forward_mode
)
def op_comm_postprocess_layer(self, state):
hidden_states, residual = self.layer_communicator.postprocess_layer(
state.pop("hidden_states_mlp_output"),
state.pop("residual_after_comm_pre_mlp"),
state.forward_batch,
)
output = dict(
positions=state.positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.tbo_subbatch_index,
)
state.clear(
expect_keys={
"positions",
"forward_batch",
"tbo_subbatch_index",
}
)
return output

class Qwen3MoeModel(Qwen2MoeModel):
    def __init__(
......
@@ -32,12 +32,27 @@ class OperationsStrategy:
        layers: torch.nn.ModuleList,
        forward_mode: ForwardMode,
    ) -> "OperationsStrategy":
-        return OperationsStrategy.concat(
-            [
-                _compute_layer_operations_strategy_tbo(layer, forward_mode)
-                for layer in layers
-            ]
-        )
+        layer_name = layers[0].__class__.__name__
+        if layer_name == "DeepseekV2DecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_deepseek_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        elif layer_name == "Qwen3MoeDecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_qwen3_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        else:
+            raise NotImplementedError

def _assert_all_same(items: List):

@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
    return items[0]


+# -------------------------------- Strategy for DeepSeek ---------------------------------------
# TODO can refactor to make it more fancy if we have more complex strategies
-def _compute_layer_operations_strategy_tbo(
+def _compute_moe_deepseek_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:

@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
            operations.YieldOperation(),
        ],
    )

# -------------------------------- Strategy for Qwen3 ---------------------------------------
# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
# convenience to adjust strategy
def _compute_moe_qwen3_layer_operations_strategy_tbo(
layer: torch.nn.Module,
forward_mode: ForwardMode,
) -> OperationsStrategy:
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_qwen3_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
return _compute_moe_qwen3_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
def _compute_moe_qwen3_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
tbo_delta_stages=0,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
layer.mlp.op_dispatch_a,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
],
)
def _compute_moe_qwen3_decode(layer):
return OperationsStrategy(
deep_gemm_num_sms=None,
tbo_delta_stages=2,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
operations.YieldOperation(),
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
operations.YieldOperation(),
layer.mlp.op_dispatch_a,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
operations.YieldOperation(),
],
)

@@ -356,14 +356,14 @@ def model_forward_maybe_tbo(
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
):
    inputs = dict(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        residual=residual,
-        zero_allocator=zero_allocator,
+        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
    )
    operations_strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode

@@ -401,7 +401,7 @@ def _model_forward_tbo_split_inputs(
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
) -> List[Dict]:
    return [
        dict(

@@ -412,7 +412,11 @@ def _model_forward_tbo_split_inputs(
            output_forward_batch=output_forward_batch,
            tbo_subbatch_index=tbo_subbatch_index,
        ),
-        zero_allocator=zero_allocator,
+        **(
+            dict(zero_allocator=zero_allocator)
+            if zero_allocator is not None
+            else {}
+        ),
    )
    for tbo_subbatch_index, output_forward_batch in enumerate(
        forward_batch.tbo_children
......
@@ -9,6 +9,7 @@ from sglang.srt.two_batch_overlap import compute_split_seq_index
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
+    DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,

@@ -104,5 +105,32 @@ class TestTwoBatchOverlapUnitTest(unittest.TestCase):
        self.assertEqual(actual, expect)

class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--enable-deepep-moe",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
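
A few notes on how the pieces above fit together, as read from this diff: OperationsStrategy.init_new_tbo now picks a per-layer schedule by decoder-layer class name, so Qwen3MoeDecoderLayer gets its own prefill and decode operation lists. The prefill schedule keeps the two sub-batches in lockstep (tbo_delta_stages=0) and budgets grouped-GEMM SMs as deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms, while the decode schedule adds an extra YieldOperation() between attention prepare and core and sets tbo_delta_stages=2, presumably staggering the sub-batches so one sits in its DeepEP dispatch/combine while the other computes. TestQwen3TwoBatchOverlap exercises the path end to end with --tp 2 --dp 2 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --disable-cuda-graph --enable-two-batch-overlap (CUDA graph is disabled because DeepEP normal mode does not support it) and with JIT DeepGEMM turned off via SGL_ENABLE_JIT_DEEPGEMM=0.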