Unverified commit 33e9bbec, authored by fzyzcjy and committed by GitHub

Make single-batch overlap compatible with offloading (#11614)

parent dcb8f090
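
Context for the change: single-batch overlap (SBO) runs the shared-experts computation concurrently with the MoE dispatch/combine communication, while the offloading path relies on module hooks that fire only when a module is invoked through its __call__. If the parent layer drives SBO itself and calls experts.dispatch(...) / experts.combine(...) directly, those hooks never run. The sketch below is not sglang code; it only illustrates the PyTorch hook behavior that motivates moving SBO inside DeepEPMoE.forward.

# Minimal illustration (illustrative names, not sglang code): PyTorch forward
# hooks run only when the module is called via module(...), not when its
# methods are invoked directly.
import torch
import torch.nn as nn

class Experts(nn.Module):
    def forward(self, x):
        return x * 2

    def dispatch(self, x):
        # Helper called directly; __call__ is bypassed, so hooks do not run.
        return x + 1

experts = Experts()
events = []
experts.register_forward_pre_hook(lambda mod, args: events.append("prefetch weights"))
experts.register_forward_hook(lambda mod, args, out: events.append("release weights"))

experts(torch.ones(1))           # hooks fire -> offloading can prefetch/release
experts.dispatch(torch.ones(1))  # hooks do not fire -> weights could still be offloaded
print(events)                    # ['prefetch weights', 'release weights']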
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 import torch
+from sglang.srt import single_batch_overlap
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -167,18 +168,20 @@ class DeepEPMoE(FusedMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        forward_shared_experts=None,
+        alt_stream=None,
     ):
-        dispatch_output = self.dispatch(
-            hidden_states, topk_idx, topk_weights, forward_batch
-        )
-        hidden_states = self.moe_impl(dispatch_output)
-        hidden_states = self.combine(
-            hidden_states,
-            dispatch_output.topk_idx,
-            dispatch_output.topk_weights,
-            forward_batch,
+        # We have to call SBO inside MoE to be compatible with hooks used in offloading
+        return single_batch_overlap.execute_sbo(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+            # SBO args
+            experts=self,
+            forward_shared_experts=forward_shared_experts,
+            alt_stream=alt_stream,
         )
-        return hidden_states

     def dispatch(
         self,
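
The comment in the hunk above ("We have to call SBO inside MoE to be compatible with hooks used in offloading") is the core of the change: DeepEPMoE.forward now accepts the optional forward_shared_experts and alt_stream arguments and delegates to single_batch_overlap.execute_sbo, so the whole dispatch/compute/overlap/combine sequence runs under the experts module's forward. The snippet below is only a generic sketch of the kind of side-stream overlap an alt_stream enables, assuming plain CUDA streams; it is not the sglang implementation.

# Sketch only: overlap side work with main-stream work on a second CUDA stream.
# Falls back to sequential execution when no side stream (or no GPU) is available.
import torch

def overlapped(main_work, side_work, alt_stream=None):
    if alt_stream is None or not torch.cuda.is_available():
        return main_work(), side_work()
    # Let the side stream wait for everything already queued on the main stream.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        side_out = side_work()    # e.g. shared experts
    main_out = main_work()        # e.g. combine communication
    # Re-synchronize so later kernels see the side stream's results.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return main_out, side_out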
@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not SboFlags.fuse_shared_experts_inside_sbo():
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
@@ -887,18 +887,27 @@
                 hidden_states.device
             )

-        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
+
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
-            # SBO args
-            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
-            experts=self.experts,
-            alt_stream=self.alt_stream,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output

         if shared_output is not None:
             x = shared_output
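
Because execute_sbo no longer returns the shared-experts output (see the single_batch_overlap hunks below), DeepseekV2MoE recovers it through the _forward_shared_experts_and_put_results closure, which writes into the enclosing scope with nonlocal. A standalone sketch of that pattern, with placeholder values instead of tensors:

# Standalone sketch of the nonlocal-closure pattern: the callee only needs a
# zero-argument callback, and the caller captures the side result itself.
def caller():
    shared_output = None

    def produce_and_store():
        nonlocal shared_output
        shared_output = "shared-experts result"  # placeholder for a tensor

    def callee(callback=None):
        if callback is not None:
            callback()  # in SBO this call is overlapped with communication
        return "routed-experts result"

    routed = callee(produce_and_store)
    return routed, shared_output

print(caller())  # ('routed-experts result', 'shared-experts result')

This keeps the experts module's return type a single tensor regardless of whether shared experts are fused into SBO.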
@@ -42,7 +42,7 @@ class CombineOverlapArgs:
     wait_event: torch.cuda.Event
     num_sms: int
     signal: Optional[torch.Tensor] = None
-    threshold: int = -1
+    threshold: int = 0


 @dataclass
@@ -61,8 +61,6 @@ def execute_sbo(
     forward_batch: ForwardBatch,
     alt_stream: Optional = None,
 ):
-    shared_output = None
-
     dispatch_output = experts.dispatch(
         hidden_states, topk_idx, topk_weights, forward_batch
     )
@@ -82,7 +80,7 @@ def execute_sbo(
     with deep_gemm_wrapper.configure_deep_gemm_num_sms(
         meta_overlap_args["compute_num_sms"]
     ):
-        shared_output = forward_shared_experts()
+        forward_shared_experts()

     hidden_states = experts.combine(
         hidden_states,
@@ -92,7 +90,7 @@ def execute_sbo(
         overlap_args=combine_overlap_args,
     )

-    return hidden_states, shared_output
+    return hidden_states


 def _compute_overlap_args(dispatch_output, alt_stream):
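
Taken together, the three files move the whole SBO sequence (dispatch, expert compute, shared-experts overlap, combine) under the experts module's forward, and execute_sbo now returns only the combined hidden states. The toy below mirrors that control flow end to end; every name is illustrative and none of it is the sglang API.

# Toy end-to-end flow: the parent layer calls the experts module; SBO runs
# entirely inside its forward, so hook-based offloading wraps the whole sequence.
import torch
import torch.nn as nn

def toy_execute_sbo(hidden_states, experts, forward_shared_experts=None):
    x = experts.dispatch(hidden_states)
    x = experts.moe_impl(x)
    if forward_shared_experts is not None:
        forward_shared_experts()   # overlapped with combine in the real code
    return experts.combine(x)      # single return value, as in this commit

class ToyDeepEPMoE(nn.Module):
    def forward(self, hidden_states, forward_shared_experts=None):
        # SBO is invoked inside forward so offloading hooks on this module fire.
        return toy_execute_sbo(hidden_states, self, forward_shared_experts)

    def dispatch(self, x):
        return x

    def moe_impl(self, x):
        return x * 2

    def combine(self, x):
        return x + 1

class ToyMoELayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = ToyDeepEPMoE()

    def forward(self, hidden_states):
        shared_output = None

        def run_shared():
            nonlocal shared_output
            shared_output = hidden_states * 0.5

        out = self.experts(hidden_states, forward_shared_experts=run_shared)
        return out if shared_output is None else out + shared_output

print(ToyMoELayer()(torch.ones(2)))  # tensor([3.5000, 3.5000])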