Unverified Commit 5e786cca authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support single batch overlap (#10422)

parent 0b9dfba7
......@@ -294,6 +294,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
| `--enable-single-batch-overlap` | Enabling single batch overlap. | False |
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
......
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
import triton
......@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import (
ceil_div,
dispose_tensor,
get_bool_env_var,
get_int_env_var,
is_cuda,
is_hip,
is_npu,
......@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
),
)
def moe_impl(self, dispatch_output: DispatchOutput):
def moe_impl(
self,
dispatch_output: DispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
):
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
......@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(dispatch_output)
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
......@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
overlap_args: Optional[Dict[str, Any]] = None,
):
return self.deepep_dispatcher.combine(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
overlap_args=overlap_args,
)
def forward_aiter(
......@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
def forward_flashinfer_cutedsl(
self,
dispatch_output: DeepEPLLOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
):
hidden_states, _, _, masked_m, _ = dispatch_output
assert self.quant_method is not None
......@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
x=hidden_states,
masked_m=masked_m,
moe_runner_config=self.moe_runner_config,
down_gemm_overlap_args=down_gemm_overlap_args,
)
return output
......
......@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
down_sm_count: Optional[int] = None,
down_signals: Optional[torch.Tensor] = None,
down_start_event: Optional[torch.cuda.Event] = None,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
......@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
masked_m,
)
if down_start_event is not None:
down_start_event.record()
# Gemm2
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
out = out.permute(1, 2, 0) # requirement of kernel
......@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
**(
dict(
sm_count=down_sm_count,
dst_signals=down_signals,
)
if down_sm_count is not None or down_signals is not None
else {}
),
) # in logical [m, k, l]
return out.permute(2, 0, 1)
from __future__ import annotations
import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
......@@ -25,6 +26,9 @@ from sglang.srt.utils import (
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.single_batch_overlap import CombineOverlapArgs
try:
from deep_ep import Buffer, Config
......@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"],
):
raise NotImplementedError
......@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"],
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_post_reorder_triton_kernel,
......@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
self.return_recv_hook = return_recv_hook
self.device_module = torch.get_device_module()
def dispatch_a(
self,
......@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
use_fp8 = True
buffer = self._get_buffer()
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
buffer.low_latency_dispatch(
hidden_states,
topk_idx,
......@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
)
)
return packed_recv_hidden, packed_recv_count, event, hook
return packed_recv_hidden, self.packed_recv_count, event, hook
def combine_a(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"],
):
hidden_states, event, hook = self._combine_core(
hidden_states,
topk_idx,
topk_weights,
overlap_args=overlap_args,
)
return hidden_states, event, hook
return hidden_states, event, hook, overlap_args
def combine_b(self, hidden_states, event, hook):
def combine_b(self, hidden_states, event, hook, overlap_args):
hook() if self.return_recv_hook else event.current_stream_wait()
if overlap_args is not None:
self.device_module.current_stream().wait_stream(overlap_args.stream)
return hidden_states
def _combine_core(
......@@ -615,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"],
):
buffer = self._get_buffer()
combined_hidden_states, event, hook = buffer.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
self.handle = None
ctx = nullcontext()
if overlap_args is not None:
overlap_args.stream.wait_event(overlap_args.wait_event)
ctx = torch.cuda.stream(overlap_args.stream)
with ctx:
combined_hidden_states, event, hook = buffer.low_latency_combine(
x=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
handle=self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
**(
dict(
overlap=overlap_args.overlap,
src_signals=overlap_args.signal,
src_signal_expect_value=overlap_args.threshold,
)
if overlap_args is not None
else {}
),
)
self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook
def _get_buffer(self):
......@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
overlap_args: Optional["CombineOverlapArgs"] = None,
):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_batch).combine_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
overlap_args=overlap_args,
)
self._combine_intermediate_state = forward_batch, inner_state
......
......@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
IS_SBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
......@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
global IS_SBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
......@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
IS_SBO_ENABLED = server_args.enable_single_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
......@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
return IS_TBO_ENABLED
def is_sbo_enabled() -> bool:
global IS_SBO_ENABLED
if IS_SBO_ENABLED is None:
IS_SBO_ENABLED = False
return IS_SBO_ENABLED
def get_tbo_token_distribution_threshold() -> float:
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
......
......@@ -47,6 +47,7 @@ if TYPE_CHECKING:
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
if is_cuda():
from sgl_kernel import scaled_fp4_quant
......@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
masked_m: torch.Tensor,
moe_runner_config: MoeRunnerConfig,
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
) -> torch.Tensor:
assert (
moe_runner_config.activation == "silu"
......@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=masked_m,
**(
dict(
down_sm_count=down_gemm_overlap_args.num_sms,
down_signals=down_gemm_overlap_args.signal,
down_start_event=down_gemm_overlap_args.start_event,
)
if down_gemm_overlap_args is not None
else {}
),
)
return out
......@@ -28,6 +28,7 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt import single_batch_overlap
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
......@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
......@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
if not SboFlags.fuse_shared_experts_inside_sbo():
shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
......@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
hidden_states.device
)
final_hidden_states = self.experts(
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
# SBO args
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
experts=self.experts,
alt_stream=self.alt_stream,
)
if sbo_shared_output is not None:
shared_output = sbo_shared_output
if shared_output is not None:
x = shared_output
......@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
def _forward_shared_experts(
self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
):
if self.num_fused_shared_experts == 0:
if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
return self.shared_experts(
hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
)
......
......@@ -377,6 +377,7 @@ class ServerArgs:
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
enable_single_batch_overlap: bool = False
tbo_token_distribution_threshold: float = 0.48
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
......@@ -2457,6 +2458,11 @@ class ServerArgs:
action="store_true",
help="Enabling two micro batches to overlap.",
)
parser.add_argument(
"--enable-single-batch-overlap",
action="store_true",
help="Let computation and communication overlap within one micro batch.",
)
parser.add_argument(
"--tbo-token-distribution-threshold",
type=float,
......
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
class SboFlags:
# TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...
@classmethod
def enable_combine_down_gemm_two_stream_overlap(cls):
return (
is_sbo_enabled()
# currently only cutedsl backend supports it
and get_moe_runner_backend().is_flashinfer_cutedsl()
)
@classmethod
def enable_combine_shared_two_stream_overlap(cls):
return is_sbo_enabled()
@classmethod
def fuse_shared_experts_inside_sbo(cls):
# TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
return cls.enable_combine_shared_two_stream_overlap()
@dataclass
class CombineOverlapArgs:
# this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
overlap: bool
stream: torch.cuda.Stream
wait_event: torch.cuda.Event
num_sms: int
signal: Optional[torch.Tensor] = None
threshold: int = -1
@dataclass
class DownGemmOverlapArgs:
num_sms: int
signal: torch.Tensor
start_event: torch.cuda.Event
def execute_sbo(
forward_shared_experts: Callable[[], Any],
experts: "DeepEPMoE",
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
alt_stream: Optional = None,
):
shared_output = None
dispatch_output = experts.dispatch(
hidden_states, topk_idx, topk_weights, forward_batch
)
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
_compute_overlap_args(dispatch_output, alt_stream)
)
hidden_states = experts.moe_impl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
e.record()
if SboFlags.enable_combine_shared_two_stream_overlap():
# TODO reduce sm for non-deepgemm
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
meta_overlap_args["compute_num_sms"]
):
shared_output = forward_shared_experts()
hidden_states = experts.combine(
hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
forward_batch,
overlap_args=combine_overlap_args,
)
return hidden_states, shared_output
def _compute_overlap_args(dispatch_output, alt_stream):
if not (
SboFlags.enable_combine_down_gemm_two_stream_overlap()
or SboFlags.enable_combine_shared_two_stream_overlap()
):
return None, None, {}
hidden_states = dispatch_output.hidden_states_fp8
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
compute_num_sms = total_num_sms - communicate_num_sms
assert alt_stream is not None
combine_wait_event = torch.cuda.Event()
combine_overlap_args = CombineOverlapArgs(
overlap=False,
num_sms=communicate_num_sms,
stream=alt_stream,
wait_event=combine_wait_event,
)
meta_overlap_args = dict(
compute_num_sms=compute_num_sms,
)
down_gemm_overlap_args = None
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
# TODO use zero_allocator to remove this `torch.zeros` call
# NOTE ours v2 use uint32 not int32 currently
combine_signal = torch.zeros(
num_local_experts, dtype=torch.uint32, device=hidden_states.device
)
down_gemm_overlap_args = DownGemmOverlapArgs(
signal=combine_signal,
start_event=combine_wait_event,
num_sms=compute_num_sms,
)
combine_overlap_args.overlap = True
combine_overlap_args.signal = combine_signal
combine_overlap_args.threshold = compute_num_sms
else:
meta_overlap_args |= dict(
record_event_after_down=combine_wait_event,
)
return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment