Unverified commit 33e9bbec authored by fzyzcjy, committed by GitHub

Make single-batch overlap compatible with offloading (#11614)

parent dcb8f090
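
The motivation, spelled out in the new in-code comment, is that single-batch overlap (SBO) has to be driven from inside the MoE module's own forward(): weight offloading is attached via module hooks, and those hooks only fire when the module itself is called, not when dispatch/compute/combine are orchestrated from outside it. A minimal, self-contained sketch of that hook mechanism (illustrative names only, not sglang code):

import torch
import torch.nn as nn

class TinyExperts(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

experts = TinyExperts()
events = []

# Hypothetical offloading hooks: a real implementation would bring the module's
# weights onto the GPU before forward() runs and release them afterwards.
experts.register_forward_pre_hook(lambda mod, args: events.append("load weights"))
experts.register_forward_hook(lambda mod, args, out: events.append("release weights"))

out = experts(torch.randn(4, 8))
print(events)  # ['load weights', 'release weights']

# The hooks fire only because the module itself was invoked; code that touched the
# weights from outside forward() would bypass them. That is why the SBO
# dispatch/compute/combine sequence is moved inside DeepEPMoE.forward below.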
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 import torch
+from sglang.srt import single_batch_overlap
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -167,18 +168,20 @@ class DeepEPMoE(FusedMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        forward_shared_experts=None,
+        alt_stream=None,
     ):
-        dispatch_output = self.dispatch(
-            hidden_states, topk_idx, topk_weights, forward_batch
-        )
-        hidden_states = self.moe_impl(dispatch_output)
-        hidden_states = self.combine(
-            hidden_states,
-            dispatch_output.topk_idx,
-            dispatch_output.topk_weights,
-            forward_batch,
-        )
-        return hidden_states
+        # We have to call SBO inside MoE to be compatible with hooks used in offloading
+        return single_batch_overlap.execute_sbo(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+            # SBO args
+            experts=self,
+            forward_shared_experts=forward_shared_experts,
+            alt_stream=alt_stream,
+        )

     def dispatch(
         self,
......
@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not SboFlags.fuse_shared_experts_inside_sbo():
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
@@ -887,18 +887,27 @@
                 hidden_states.device
             )

-        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
+
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
-            # SBO args
-            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
-            experts=self.experts,
-            alt_stream=self.alt_stream,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output

         if shared_output is not None:
             x = shared_output
......
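
Two things to note in the DeepseekV2MoE change: the shared-expert result no longer comes back through the return value of the SBO helper, so the callback stores it into the enclosing scope with nonlocal, and the SBO-only keyword arguments are spliced into the self.experts(...) call only when the fuse-shared-experts flag is set, via the **(dict(...) if flag else {}) idiom. A small self-contained sketch of both patterns (the names below are stand-ins, not sglang APIs):

def call_experts(hidden, forward_shared_experts=None, alt_stream=None):
    # Stand-in for self.experts(...): run the shared-expert callback, if any,
    # alongside the routed-expert work, then return only the routed output.
    if forward_shared_experts is not None:
        forward_shared_experts()
    return f"routed({hidden})"

def moe_forward(hidden, fuse_shared_experts_inside_sbo):
    shared_output = None

    def _forward_shared_experts_and_put_results():
        # Publish the result by side effect instead of returning it up the chain.
        nonlocal shared_output
        shared_output = f"shared({hidden})"

    final = call_experts(
        hidden,
        # Only pass the SBO-specific kwargs when the feature is enabled.
        **(
            dict(
                forward_shared_experts=_forward_shared_experts_and_put_results,
                alt_stream="alt_stream",
            )
            if fuse_shared_experts_inside_sbo
            else {}
        ),
    )
    return final, shared_output

print(moe_forward("h", False))  # ('routed(h)', None)
print(moe_forward("h", True))   # ('routed(h)', 'shared(h)')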
...@@ -42,7 +42,7 @@ class CombineOverlapArgs: ...@@ -42,7 +42,7 @@ class CombineOverlapArgs:
wait_event: torch.cuda.Event wait_event: torch.cuda.Event
num_sms: int num_sms: int
signal: Optional[torch.Tensor] = None signal: Optional[torch.Tensor] = None
threshold: int = -1 threshold: int = 0
@dataclass @dataclass
@@ -61,8 +61,6 @@ def execute_sbo(
     forward_batch: ForwardBatch,
     alt_stream: Optional = None,
 ):
-    shared_output = None
-
     dispatch_output = experts.dispatch(
         hidden_states, topk_idx, topk_weights, forward_batch
     )
@@ -82,7 +80,7 @@
         with deep_gemm_wrapper.configure_deep_gemm_num_sms(
             meta_overlap_args["compute_num_sms"]
         ):
-            shared_output = forward_shared_experts()
+            forward_shared_experts()

     hidden_states = experts.combine(
         hidden_states,
@@ -92,7 +90,7 @@
         overlap_args=combine_overlap_args,
     )

-    return hidden_states, shared_output
+    return hidden_states


 def _compute_overlap_args(dispatch_output, alt_stream):
......
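
The single_batch_overlap side mirrors that: execute_sbo no longer tracks a shared_output of its own and now returns only the combined hidden states, so whoever needs the shared-expert result must capture it inside the forward_shared_experts callback (as DeepseekV2MoE does above); CombineOverlapArgs.threshold also changes its default from -1 to 0. A before/after sketch of the return contract, with placeholder values rather than the real tensors:

def execute_sbo_old(forward_shared_experts=None):
    # Old contract: the helper captured and returned the shared-expert output.
    shared_output = None
    if forward_shared_experts is not None:
        shared_output = forward_shared_experts()
    return "combined_hidden_states", shared_output

def execute_sbo_new(forward_shared_experts=None):
    # New contract: the callback is fire-and-forget here; its result, if needed,
    # is stored by the callback itself on the caller's side.
    if forward_shared_experts is not None:
        forward_shared_experts()
    return "combined_hidden_states"

print(execute_sbo_old(lambda: "shared"))  # ('combined_hidden_states', 'shared')
print(execute_sbo_new(lambda: "shared"))  # 'combined_hidden_states'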