Unverified Commit 5e786cca authored by fzyzcjy, committed by GitHub

Support single batch overlap (#10422)

parent 0b9dfba7
...@@ -294,6 +294,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap; determines whether to use two-batch-overlap or two-chunk-overlap. Set to 0 to disable two-chunk-overlap. | 0.48 |
| `--enable-single-batch-overlap` | Enable single-batch overlap: let computation and communication overlap within one micro batch. | False |
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
...
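A minimal usage sketch for the new flag, assuming the offline `sglang.Engine` entry point forwards keyword arguments to `ServerArgs`; the model, parallelism, and surrounding MoE/DeepEP configuration are placeholders, since the SBO paths in this PR target the DeepEP low-latency dispatcher with the FlashInfer CuteDSL MoE backend:

```python
# Hypothetical usage sketch; everything except enable_single_batch_overlap is a placeholder.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",   # placeholder MoE model
    tp_size=8,                              # placeholder parallelism
    enable_single_batch_overlap=True,       # the flag added by this PR
    # NOTE: the overlap paths also assume the DeepEP low-latency dispatcher and the
    # FlashInfer CuteDSL MoE runner backend are enabled for the deployment.
)
print(llm.generate("Hello", {"max_new_tokens": 8}))
```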
from __future__ import annotations
import logging
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
import triton
...@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import (
    ceil_div,
    dispose_tensor,
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
    is_npu,
...@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
            ),
        )

    def moe_impl(
        self,
        dispatch_output: DispatchOutput,
        down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
    ):
        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

        if _use_aiter:
...@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if get_moe_runner_backend().is_flashinfer_cutedsl():
                return self.forward_flashinfer_cutedsl(
                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
                )
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
        else:
...@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
        overlap_args: Optional[Dict[str, Any]] = None,
    ):
        return self.deepep_dispatcher.combine(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
            overlap_args=overlap_args,
        )

    def forward_aiter(
...@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
    def forward_flashinfer_cutedsl(
        self,
        dispatch_output: DeepEPLLOutput,
        down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
    ):
        hidden_states, _, _, masked_m, _ = dispatch_output
        assert self.quant_method is not None
...@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
            x=hidden_states,
            masked_m=masked_m,
            moe_runner_config=self.moe_runner_config,
            down_gemm_overlap_args=down_gemm_overlap_args,
        )
        return output
...
...@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
    down_sm_count: Optional[int] = None,
    down_signals: Optional[torch.Tensor] = None,
    down_start_event: Optional[torch.cuda.Event] = None,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
...@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
        masked_m,
    )

    if down_start_event is not None:
        down_start_event.record()

    # Gemm2
    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
    out = out.permute(1, 2, 0)  # requirement of kernel
...@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
        **(
            dict(
                sm_count=down_sm_count,
                dst_signals=down_signals,
            )
            if down_sm_count is not None or down_signals is not None
            else {}
        ),
    )  # in logical [m, k, l]
    return out.permute(2, 0, 1)
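The `**( dict(...) if ... else {} )` construction above forwards `sm_count` / `dst_signals` to the grouped GEMM only when overlap is requested, so the non-overlap call stays identical to the original. A generic sketch of this optional-kwargs pattern, with hypothetical names:

```python
# Generic illustration of the conditional keyword-forwarding used above;
# `kernel` and `sm_count` here are hypothetical stand-ins.
from typing import Any, Callable, Dict, Optional

def call_kernel(kernel: Callable[..., Any], *args, sm_count: Optional[int] = None, **kwargs) -> Any:
    extra: Dict[str, Any] = {}
    if sm_count is not None:
        extra["sm_count"] = sm_count  # pass the knob only when the caller opts in
    return kernel(*args, **kwargs, **extra)
```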
from __future__ import annotations
import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
...@@ -25,6 +26,9 @@ from sglang.srt.utils import (
_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.single_batch_overlap import CombineOverlapArgs

try:
    from deep_ep import Buffer, Config
...@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        raise NotImplementedError
...@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        from sglang.srt.layers.moe.ep_moe.kernels import (
            deepep_post_reorder_triton_kernel,
...@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        """
        self.return_recv_hook = return_recv_hook
        self.device_module = torch.get_device_module()
    def dispatch_a(
        self,
...@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        use_fp8 = True
        buffer = self._get_buffer()
        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
            buffer.low_latency_dispatch(
                hidden_states,
                topk_idx,
...@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
            )
        )
        return packed_recv_hidden, self.packed_recv_count, event, hook
    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        hidden_states, event, hook = self._combine_core(
            hidden_states,
            topk_idx,
            topk_weights,
            overlap_args=overlap_args,
        )
        return hidden_states, event, hook, overlap_args

    def combine_b(self, hidden_states, event, hook, overlap_args):
        hook() if self.return_recv_hook else event.current_stream_wait()
        if overlap_args is not None:
            self.device_module.current_stream().wait_stream(overlap_args.stream)
        return hidden_states
    def _combine_core(
...@@ -615,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        buffer = self._get_buffer()

        ctx = nullcontext()
        if overlap_args is not None:
            overlap_args.stream.wait_event(overlap_args.wait_event)
            ctx = torch.cuda.stream(overlap_args.stream)

        with ctx:
            combined_hidden_states, event, hook = buffer.low_latency_combine(
                x=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                handle=self.handle,
                async_finish=not self.return_recv_hook,
                return_recv_hook=self.return_recv_hook,
                **(
                    dict(
                        overlap=overlap_args.overlap,
                        src_signals=overlap_args.signal,
                        src_signal_expect_value=overlap_args.threshold,
                    )
                    if overlap_args is not None
                    else {}
                ),
            )
        self.packed_recv_count = self.handle = None
        return combined_hidden_states, event, hook

    def _get_buffer(self):
...@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
        overlap_args: Optional["CombineOverlapArgs"] = None,
    ):
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
        inner_state = self._get_impl(forward_batch).combine_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            overlap_args=overlap_args,
        )
        self._combine_intermediate_state = forward_batch, inner_state
...
...@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
IS_SBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
...@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
    global DEEPEP_MODE
    global DEEPEP_CONFIG
    global IS_TBO_ENABLED
    global IS_SBO_ENABLED
    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
...@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
    DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
    DEEPEP_CONFIG = server_args.deepep_config or ""
    IS_TBO_ENABLED = server_args.enable_two_batch_overlap
    IS_SBO_ENABLED = server_args.enable_single_batch_overlap
    TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
...@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
    return IS_TBO_ENABLED


def is_sbo_enabled() -> bool:
    global IS_SBO_ENABLED
    if IS_SBO_ENABLED is None:
        IS_SBO_ENABLED = False
    return IS_SBO_ENABLED


def get_tbo_token_distribution_threshold() -> float:
    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
    if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
...
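A minimal sketch of how the new flag reaches this module, assuming `initialize_moe_config` and `is_sbo_enabled` both live in `sglang.srt.layers.moe.utils` as the hunk above suggests; the model path is a placeholder:

```python
# Hypothetical wiring check (sketch only; ServerArgs may perform extra validation at init).
from sglang.srt.layers.moe.utils import initialize_moe_config, is_sbo_enabled
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder
    enable_single_batch_overlap=True,      # sets IS_SBO_ENABLED
)
initialize_moe_config(server_args)
assert is_sbo_enabled()  # consulted by SboFlags at MoE forward time
```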
...@@ -47,6 +47,7 @@ if TYPE_CHECKING:
        CombineInput,
        StandardDispatchOutput,
    )
    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

if is_cuda():
    from sgl_kernel import scaled_fp4_quant
...@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        x: torch.Tensor,
        masked_m: torch.Tensor,
        moe_runner_config: MoeRunnerConfig,
        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
    ) -> torch.Tensor:
        assert (
            moe_runner_config.activation == "silu"
...@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            w2_blockscale=layer.w2_blockscale_swizzled,
            w2_alpha=layer.g2_alphas,
            masked_m=masked_m,
            **(
                dict(
                    down_sm_count=down_gemm_overlap_args.num_sms,
                    down_signals=down_gemm_overlap_args.signal,
                    down_start_event=down_gemm_overlap_args.start_event,
                )
                if down_gemm_overlap_args is not None
                else {}
            ),
        )
        return out
...@@ -28,6 +28,7 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt import single_batch_overlap
from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,
...@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
...@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            if not SboFlags.fuse_shared_experts_inside_sbo():
                shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
...@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
                hidden_states.device
            )

        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
            # SBO args
            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
            experts=self.experts,
            alt_stream=self.alt_stream,
        )
        if sbo_shared_output is not None:
            shared_output = sbo_shared_output

        if shared_output is not None:
            x = shared_output
...@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
    def _forward_shared_experts(
        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
    ):
        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
            return self.shared_experts(
                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
            )
...
...@@ -377,6 +377,7 @@ class ServerArgs:
    enable_dp_attention: bool = False
    enable_dp_lm_head: bool = False
    enable_two_batch_overlap: bool = False
    enable_single_batch_overlap: bool = False
    tbo_token_distribution_threshold: float = 0.48
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
...@@ -2457,6 +2458,11 @@ class ServerArgs:
            action="store_true",
            help="Enabling two micro batches to overlap.",
        )
        parser.add_argument(
            "--enable-single-batch-overlap",
            action="store_true",
            help="Let computation and communication overlap within one micro batch.",
        )
        parser.add_argument(
            "--tbo-token-distribution-threshold",
            type=float,
...
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch

from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var

if TYPE_CHECKING:
    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE


class SboFlags:
    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        return (
            is_sbo_enabled()
            # currently only cutedsl backend supports it
            and get_moe_runner_backend().is_flashinfer_cutedsl()
        )

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        return is_sbo_enabled()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
        return cls.enable_combine_shared_two_stream_overlap()


@dataclass
class CombineOverlapArgs:
    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
    overlap: bool
    stream: torch.cuda.Stream
    wait_event: torch.cuda.Event
    num_sms: int
    signal: Optional[torch.Tensor] = None
    threshold: int = -1


@dataclass
class DownGemmOverlapArgs:
    num_sms: int
    signal: torch.Tensor
    start_event: torch.cuda.Event


def execute_sbo(
    forward_shared_experts: Callable[[], Any],
    experts: "DeepEPMoE",
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch: ForwardBatch,
    alt_stream: Optional = None,
):
    shared_output = None

    dispatch_output = experts.dispatch(
        hidden_states, topk_idx, topk_weights, forward_batch
    )

    combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
        _compute_overlap_args(dispatch_output, alt_stream)
    )

    hidden_states = experts.moe_impl(
        dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
    )
    if (e := meta_overlap_args.get("record_event_after_down")) is not None:
        e.record()

    if SboFlags.enable_combine_shared_two_stream_overlap():
        # TODO reduce sm for non-deepgemm
        with deep_gemm_wrapper.configure_deep_gemm_num_sms(
            meta_overlap_args["compute_num_sms"]
        ):
            shared_output = forward_shared_experts()

    hidden_states = experts.combine(
        hidden_states,
        dispatch_output.topk_idx,
        dispatch_output.topk_weights,
        forward_batch,
        overlap_args=combine_overlap_args,
    )

    return hidden_states, shared_output


def _compute_overlap_args(dispatch_output, alt_stream):
    if not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    hidden_states = dispatch_output.hidden_states_fp8
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]
    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count
    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = dict(
        compute_num_sms=compute_num_sms,
    )

    down_gemm_overlap_args = None
    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # TODO use zero_allocator to remove this `torch.zeros` call
        # NOTE ours v2 use uint32 not int32 currently
        combine_signal = torch.zeros(
            num_local_experts, dtype=torch.uint32, device=hidden_states.device
        )
        down_gemm_overlap_args = DownGemmOverlapArgs(
            signal=combine_signal,
            start_event=combine_wait_event,
            num_sms=compute_num_sms,
        )
        combine_overlap_args.overlap = True
        combine_overlap_args.signal = combine_signal
        combine_overlap_args.threshold = compute_num_sms
    else:
        meta_overlap_args |= dict(
            record_event_after_down=combine_wait_event,
        )

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
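For readers unfamiliar with the stream/event choreography that `CombineOverlapArgs` and `execute_sbo` encode, the following self-contained sketch shows the same pattern with plain PyTorch ops standing in for the down GEMM, the shared experts, and the DeepEP combine; it illustrates the technique, not the actual kernels:

```python
# Generic two-stream overlap sketch: compute stays on the main stream, the "combine"
# runs on a side stream, and events enforce the producer/consumer ordering.
import torch

def overlapped_step(x: torch.Tensor, w_down: torch.Tensor, w_shared: torch.Tensor):
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()      # plays the role of alt_stream
    ready = torch.cuda.Event()      # plays the role of CombineOverlapArgs.wait_event

    down_out = x @ w_down           # stand-in for the down GEMM on the main stream
    ready.record(main)              # analogous to record_event_after_down

    side.wait_event(ready)          # combine must not start before its input is ready
    with torch.cuda.stream(side):
        combined = down_out * 2.0   # stand-in for buffer.low_latency_combine(...)

    shared = x @ w_shared           # shared experts overlap with the side stream

    main.wait_stream(side)          # like combine_b: rejoin before using `combined`
    return combined + shared

if torch.cuda.is_available():
    x = torch.randn(128, 256, device="cuda")
    w_down = torch.randn(256, 256, device="cuda")
    w_shared = torch.randn(256, 256, device="cuda")
    out = overlapped_step(x, w_down, w_shared)
```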