Unverified commit 0d477880 authored by fzyzcjy, committed by GitHub

Support overlapping two batches (#4068)

parent f4560373
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
class TboAttnBackend(AttentionBackend):
def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
super().__init__()
self.primary = primary
self.children = children
@classmethod
def init_new(cls, creator: Callable[[], AttentionBackend]):
return cls(
primary=creator(),
children=[creator() for _ in range(2)],
)
def init_forward_metadata(self, forward_batch: "ForwardBatch"):
self.primary.init_forward_metadata(forward_batch=forward_batch)
if forward_batch.tbo_children is not None:
for child, forward_batch_child in zip(
self.children, forward_batch.tbo_children, strict=True
):
if forward_batch_child.batch_size > 0:
child.init_forward_metadata(forward_batch=forward_batch_child)
def init_cuda_graph_state(self, max_bs: int):
self.primary.init_cuda_graph_state(max_bs=max_bs)
for item in self.children:
# TODO: for the children, a smaller max_bs could be passed as an optimization
item.init_cuda_graph_state(max_bs=max_bs)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.primary.init_forward_metadata_capture_cuda_graph(
bs=bs,
num_tokens=num_tokens,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
)
self._init_forward_metadata_cuda_graph_children(
fn_name="init_forward_metadata_capture_cuda_graph",
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
capture_num_tokens=num_tokens,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.primary.init_forward_metadata_replay_cuda_graph(
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
seq_lens_cpu=seq_lens_cpu,
)
self._init_forward_metadata_cuda_graph_children(
fn_name="init_forward_metadata_replay_cuda_graph",
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
replay_seq_lens_sum=seq_lens_sum,
replay_seq_lens_cpu=seq_lens_cpu,
)
def _init_forward_metadata_cuda_graph_children(
self,
fn_name: str,
# common args
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
from sglang.srt.model_executor.forward_batch_info import ForwardMode
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
)
tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
forward_mode=forward_mode_for_tbo_split,
num_tokens=num_tokens,
extend_lens=None,
)
tbo_split_token_index = two_batch_overlap.compute_split_token_index(
split_seq_index=tbo_split_seq_index,
forward_mode=forward_mode_for_tbo_split,
extend_seq_lens=None,
)
num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index
bs_child_left = num_tokens_child_left
bs_child_right = num_tokens_child_right
assert (
num_tokens_child_left > 0 and num_tokens_child_right > 0
), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"
common_pre_split_args = dict(
fn_name=fn_name,
bs=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=encoder_lens,
forward_mode=forward_mode,
spec_info=spec_info,
capture_num_tokens=capture_num_tokens,
replay_seq_lens_sum=replay_seq_lens_sum,
replay_seq_lens_cpu=replay_seq_lens_cpu,
)
args_left = _init_forward_metadata_cuda_graph_split(
output_bs=bs_child_left,
seq_slice=slice(None, tbo_split_seq_index),
**common_pre_split_args,
)
args_right = _init_forward_metadata_cuda_graph_split(
output_bs=bs_child_right,
seq_slice=slice(tbo_split_seq_index, None),
**common_pre_split_args,
)
child_left, child_right = self.children
getattr(child_left, fn_name)(**args_left)
getattr(child_right, fn_name)(**args_right)
def get_cuda_graph_seq_len_fill_value(self):
ans = self.primary.get_cuda_graph_seq_len_fill_value()
for child in self.children:
assert ans == child.get_cuda_graph_seq_len_fill_value()
return ans
def forward_extend(self, *args, **kwargs):
return self.primary.forward_extend(*args, **kwargs)
def forward_decode(self, *args, **kwargs):
return self.primary.forward_decode(*args, **kwargs)
def _init_forward_metadata_cuda_graph_split(
fn_name: str,
seq_slice: slice,
output_bs: int,
# common args
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
assert encoder_lens is None, "encoder_lens is not supported yet"
assert spec_info is None, "spec_info is not supported yet"
ans = dict(
bs=output_bs,
req_pool_indices=req_pool_indices[seq_slice],
seq_lens=seq_lens[seq_slice],
# directly forward
forward_mode=forward_mode,
# ignore
encoder_lens=None,
spec_info=None,
)
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
ans.update(
dict(
num_tokens=output_bs,
)
)
elif fn_name == "init_forward_metadata_replay_cuda_graph":
output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
ans.update(
dict(
seq_lens_sum=output_seq_lens_cpu.sum().item(),
seq_lens_cpu=output_seq_lens_cpu,
)
)
else:
raise NotImplementedError
return ans
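# --- Illustrative sketch (not part of the commit) -----------------------------
# Standalone toy of what the split above produces for a decode batch: each child
# backend receives a contiguous slice of the parent's per-request tensors, and
# for decode bs == num_tokens, so a parent of bs=8 yields two children of bs=4.
# All values are made up for illustration.
import torch

bs = 8
tbo_split_seq_index = bs // 2  # compute_split_seq_index() for decode
req_pool_indices = torch.arange(bs)
seq_lens = torch.full((bs,), 17)

args_left = dict(
    bs=tbo_split_seq_index,
    req_pool_indices=req_pool_indices[:tbo_split_seq_index],
    seq_lens=seq_lens[:tbo_split_seq_index],
)
args_right = dict(
    bs=bs - tbo_split_seq_index,
    req_pool_indices=req_pool_indices[tbo_split_seq_index:],
    seq_lens=seq_lens[tbo_split_seq_index:],
)
assert args_left["bs"] == args_right["bs"] == 4
assert torch.equal(
    torch.cat([args_left["seq_lens"], args_right["seq_lens"]]), seq_lens
)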
......@@ -391,3 +391,16 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
RuntimeCache.get = __patched_func
yield
RuntimeCache.get = origin_func
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
if num_sms is None:
yield
else:
original_num_sms = deep_gemm.get_num_sms()
deep_gemm.set_num_sms(num_sms)
try:
yield
finally:
deep_gemm.set_num_sms(original_num_sms)
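# --- Usage sketch (not part of the commit) -------------------------------------
# The context manager above temporarily caps how many SMs DeepGEMM may use, so
# that during overlapped prefill some SMs stay free for DeepEP communication
# kernels, and it always restores the previous value. Hypothetical call site
# (run_moe_gemms is a placeholder name, not a real function):
#
#     with configure_deep_gemm_num_sms(num_sms=strategy.deep_gemm_num_sms):
#         run_moe_gemms()
#     # deep_gemm.get_num_sms() is back to its original value here
#
# Passing num_sms=None makes it a no-op, which is what the decode strategy does.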
......@@ -78,6 +78,7 @@ global_server_args_dict = {
"disable_radix_cache": ServerArgs.disable_radix_cache,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"deepep_config": ServerArgs.deepep_config,
......@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False
tbo_split_seq_index: Optional[int] = None
global_forward_mode: Optional[ForwardMode] = None
# For processing logprobs
return_logprob: bool = False
......@@ -1624,6 +1627,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
or global_server_args_dict["attention_backend"] == "cutlass_mla"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = self.seq_lens.cpu()
else:
......@@ -1651,6 +1655,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
tbo_split_seq_index=self.tbo_split_seq_index,
global_forward_mode=self.global_forward_mode,
seq_lens_cpu=seq_lens_cpu,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
......@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool
tbo_split_seq_index: Optional[int]
global_forward_mode: Optional[ForwardMode]
# For extend
extend_num_tokens: Optional[int]
......
......@@ -34,6 +34,7 @@ import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
......@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_logger,
......@@ -1648,6 +1651,9 @@ class Scheduler(
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
)
@staticmethod
......@@ -1661,6 +1667,9 @@ class Scheduler(
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
):
# Check if other DP workers have running batches
if local_batch is None:
......@@ -1696,17 +1705,26 @@ class Scheduler(
is_extend_in_batch = (
local_batch.forward_mode.is_extend() if local_batch else False
)
tbo_preparer = TboDPAttentionPreparer()
local_info = torch.tensor(
[
num_tokens,
can_cuda_graph,
num_tokens_for_logprob,
is_extend_in_batch,
*tbo_preparer.prepare_all_gather(
local_batch,
deepep_mode,
enable_deepep_moe,
enable_two_batch_overlap,
),
],
dtype=torch.int64,
)
global_info = torch.empty(
(dp_size, attn_tp_size, 6),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
......@@ -1719,6 +1737,10 @@ class Scheduler(
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()
tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
global_info[:, :, 4:6]
)
if local_batch is None and max(global_num_tokens) > 0:
local_batch = get_idle_batch()
......@@ -1732,6 +1754,8 @@ class Scheduler(
local_batch.global_num_tokens_for_logprob = (
global_num_tokens_for_logprob
)
local_batch.tbo_split_seq_index = tbo_split_seq_index
local_batch.global_forward_mode = global_forward_mode
# Check forward mode for cuda graph
if not disable_cuda_graph:
......
......@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from sglang.srt import two_batch_overlap
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
......@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import (
TboCudaGraphRunnerUtils,
TboForwardBatchPreparer,
)
from sglang.srt.utils import (
get_available_gpu_memory,
get_device_memory_capacity,
......@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
model_runner.req_to_token_pool.size
]
if server_args.enable_two_batch_overlap:
capture_bs = [bs for bs in capture_bs if bs >= 2]
if server_args.cuda_graph_max_bs:
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
if max(capture_bs) < server_args.cuda_graph_max_bs:
......@@ -349,7 +357,14 @@ class CudaGraphRunner:
if self.is_encoder_decoder
else True
)
is_tbo_supported = (
forward_batch.can_run_tbo
if self.model_runner.server_args.enable_two_batch_overlap
else True
)
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
def capture(self):
with graph_capture() as graph_capture_context:
......@@ -466,7 +481,12 @@ class CudaGraphRunner:
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
num_token_non_padded=self.num_token_non_padded,
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
self, num_tokens
),
global_forward_mode=self.capture_forward_mode,
)
TboForwardBatchPreparer.prepare(forward_batch)
if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
......
......@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import triton
......@@ -239,6 +240,7 @@ class ForwardBatch:
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None
# Speculative decoding
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
......@@ -252,12 +254,18 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None
tbo_split_seq_index: Optional[int] = None
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List["ForwardBatch"]] = None
@classmethod
def init_new(
cls,
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
......@@ -281,6 +289,7 @@ class ForwardBatch:
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
global_forward_mode=batch.global_forward_mode,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
req_to_token_pool=model_runner.req_to_token_pool,
......@@ -294,6 +303,7 @@ class ForwardBatch:
num_token_non_padded=torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True),
tbo_split_seq_index=batch.tbo_split_seq_index,
)
# For DP attention
......@@ -316,6 +326,7 @@ class ForwardBatch:
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret)
return ret
# Override the positions with spec_info
......@@ -364,6 +375,8 @@ class ForwardBatch:
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)
TboForwardBatchPreparer.prepare(ret)
return ret
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
......@@ -588,6 +601,10 @@ class ForwardBatch:
# Precompute the kv indices for each chunk
self.prepare_chunked_kv_indices(device)
@property
def can_run_tbo(self):
return self.tbo_split_seq_index is not None
class PPProxyTensors:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
......
......@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
......@@ -198,6 +199,7 @@ class ModelRunner:
"disable_radix_cache": server_args.disable_radix_cache,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_two_batch_overlap": server_args.enable_two_batch_overlap,
"enable_dp_lm_head": server_args.enable_dp_lm_head,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
......@@ -994,6 +996,13 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.enable_two_batch_overlap:
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
else:
self.attn_backend = self._get_attention_backend()
# TODO unify with 6338
def _get_attention_backend(self):
if self.server_args.attention_backend == "flashinfer":
if not self.use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
......@@ -1003,17 +1012,17 @@ class ModelRunner:
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
return FlashInferAttnBackend(self)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
return FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "aiter":
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
return AiterAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
......@@ -1028,21 +1037,21 @@ class ModelRunner:
DoubleSparseAttnBackend,
)
return DoubleSparseAttnBackend(self)
else:
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
return TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
from sglang.srt.layers.attention.torch_native_backend import (
TorchNativeAttnBackend,
)
return TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
return FlashMLABackend(self)
elif self.server_args.attention_backend == "fa3":
assert (
torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
......@@ -1054,13 +1063,13 @@ class ModelRunner:
FlashAttentionBackend,
)
return FlashAttentionBackend(self)
elif self.server_args.attention_backend == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
)
return CutlassMLABackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
......
......@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
......@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
self.config = config
self.layer_id = layer_id
if self.tp_size > config.n_routed_experts:
......@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
else None
)
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
......@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
def get_moe_weights(self):
return [
......@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
return None
def op_gate(self, state):
if is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
):
# router_logits: (num_tokens, n_experts)
......@@ -432,115 +433,105 @@ class DeepseekV2MoE(nn.Module):
state.router_logits = None
def op_shared_experts(self, state):
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
state.forward_batch.forward_mode, hidden_states_mlp_input
):
state.shared_output = self.shared_experts(hidden_states_mlp_input)
else:
state.shared_output = None
def op_select_experts(self, state):
router_logits = state.pop("router_logits")
hidden_states = state.hidden_states_mlp_input
if router_logits is not None:
state.topk_weights_local, state.topk_idx_local = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
def op_dispatch_a(self, state):
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
hidden_states=state.hidden_states_mlp_input,
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_dispatch_b(self, state):
if self.ep_size > 1:
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
forward_mode=state.forward_batch.forward_mode,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_combine_b(self, state):
if self.ep_size > 1:
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_output(self, state):
final_hidden_states = state.pop("hidden_states_after_combine")
final_hidden_states *= self.routed_scaling_factor
if (s := state.pop("shared_output")) is not None:
final_hidden_states = final_hidden_states + s
state.hidden_states_mlp_output = final_hidden_states
......@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
tbo_subbatch_index: Optional[int] = None,
):
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
......@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
positions=positions,
zero_allocator=zero_allocator,
tbo_subbatch_index=tbo_subbatch_index,
)
)
......@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
state.forward_batch,
)
output = dict(
positions=state.positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
tbo_subbatch_index=state.tbo_subbatch_index,
)
state.clear(
expect_keys={
"positions",
"forward_batch",
"zero_allocator",
"tbo_subbatch_index",
}
)
return output
class DeepseekV2Model(nn.Module):
......@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
super().__init__()
self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size
self.first_k_dense_replace = config.first_k_dense_replace
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
......@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
total_num_layers = len(self.layers)
device = input_embeds.device if input_embeds is not None else input_ids.device
zero_allocator = BumpAllocator(
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
dtype=torch.float32,
device=device,
)
if input_embeds is None:
......@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
hidden_states = input_embeds
residual = None
normal_num_layers = (
self.first_k_dense_replace
if forward_batch.can_run_tbo
else total_num_layers
)
for i in range(normal_num_layers):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual, zero_allocator
)
if normal_num_layers != total_num_layers:
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers[normal_num_layers:],
enable_tbo=True,
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
residual=residual,
zero_allocator=zero_allocator,
)
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
......@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
......
......@@ -12,7 +12,7 @@ if _ENABLE_PROFILE:
def execute_operations(inputs, operations):
stages = _convert_operations_to_stages(operations)
executor = _StageExecutor("primary", stages, inputs=inputs)
for _ in range(executor.num_stages):
executor.next()
......@@ -20,6 +20,37 @@ def execute_operations(inputs, operations):
return executor.output
def execute_overlapped_operations(
inputs_arr: Sequence,
operations_arr: Sequence,
delta_stages: Sequence[int],
) -> Sequence:
# Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
inputs_a, inputs_b = inputs_arr
operations_a, operations_b = operations_arr
delta_stage_a, delta_stage_b = delta_stages
assert delta_stage_a == 0
delta_stage = delta_stage_b
stages_a = _convert_operations_to_stages(operations_a)
stages_b = _convert_operations_to_stages(operations_b)
executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
for _ in range(delta_stage):
executor_a.next()
for _ in range(executor_a.num_stages - delta_stage):
executor_a.next()
executor_b.next()
for _ in range(delta_stage):
executor_b.next()
assert executor_a.done and executor_b.done
return [executor_a.output, executor_b.output]
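# --- Toy walk-through (not part of the commit) ---------------------------------
# Standalone simulation of the interleaving above: micro-batch "b" always runs
# `delta_stage` stages behind micro-batch "a", which is what lets a communication
# stage of one sub-batch overlap with a compute stage of the other. Stage names
# are made up.
def _simulate_overlap(num_stages, delta_stage):
    order = []
    a = iter([f"a{i}" for i in range(num_stages)])
    b = iter([f"b{i}" for i in range(num_stages)])
    for _ in range(delta_stage):
        order.append(next(a))
    for _ in range(num_stages - delta_stage):
        order.append(next(a))
        order.append(next(b))
    for _ in range(delta_stage):
        order.append(next(b))
    return order


# tbo_delta_stages=2 (the decode strategy below) with 5 stages per micro-batch:
assert _simulate_overlap(5, 2) == [
    "a0", "a1", "a2", "b0", "a3", "b1", "a4", "b2", "b3", "b4"
]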
class YieldOperation:
pass
......@@ -109,6 +140,9 @@ class _StateDict:
for k, v in values.items():
setattr(self, k, v)
def get(self, item):
return self._data.get(item)
def clear(self, expect_keys: Sequence[str]):
if set(self._data.keys()) != set(expect_keys):
raise Exception(
......@@ -119,6 +153,7 @@ class _StateDict:
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
operations = _decorate_operations(operations)
operation_chunks = list(
_chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
)
......@@ -140,7 +175,7 @@ def _chunk_by_separator(
yield pending_items
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
return [_decorate_operation(op, debug_name_prefix) for op in operations]
......
from dataclasses import dataclass
from typing import List, Optional
import torch
from sglang.srt import operations
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.operations import Operation
@dataclass
class OperationsStrategy:
operations: List[Operation]
deep_gemm_num_sms: Optional[int] = None
tbo_delta_stages: Optional[int] = None
@classmethod
def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
return OperationsStrategy(
operations=[x for item in items for x in item.operations],
deep_gemm_num_sms=_assert_all_same(
[item.deep_gemm_num_sms for item in items]
),
tbo_delta_stages=_assert_all_same(
[item.tbo_delta_stages for item in items]
),
)
@staticmethod
def init_new_tbo(
layers: torch.nn.ModuleList,
forward_mode: ForwardMode,
) -> "OperationsStrategy":
return OperationsStrategy.concat(
[
_compute_layer_operations_strategy_tbo(layer, forward_mode)
for layer in layers
]
)
def _assert_all_same(items: List):
assert all(item == items[0] for item in items)
return items[0]
# TODO can refactor to make it more fancy if we have more complex strategies
def _compute_layer_operations_strategy_tbo(
layer: torch.nn.Module,
forward_mode: ForwardMode,
) -> OperationsStrategy:
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_deepseek_blog_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
return _compute_moe_deepseek_blog_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
def _compute_moe_deepseek_blog_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
tbo_delta_stages=0,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
layer.mlp.op_dispatch_a,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_shared_experts,
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
],
)
def _compute_moe_deepseek_blog_decode(layer):
return OperationsStrategy(
deep_gemm_num_sms=None,
tbo_delta_stages=2,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
operations.YieldOperation(),
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
operations.YieldOperation(),
layer.mlp.op_dispatch_a,
layer.mlp.op_shared_experts,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
operations.YieldOperation(),
],
)
......@@ -167,6 +167,7 @@ class ServerArgs:
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
......@@ -1144,6 +1145,11 @@ class ServerArgs:
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-two-batch-overlap",
action="store_true",
help="Enabling two micro batches to overlap.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
......
import dataclasses
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.utils import BumpAllocator, DeepEPMode
if TYPE_CHECKING:
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
# -------------------------------- Compute Basic Info ---------------------------------------
# TODO: TBO could be disabled automatically when the batch size is too small, since it would slow things down
def compute_split_seq_index(
forward_mode: "ForwardMode",
num_tokens: int,
extend_lens: Optional[Sequence[int]],
) -> Optional[int]:
if forward_mode.is_extend():
assert extend_lens is not None
return _split_array_by_half_sum(extend_lens)
elif forward_mode.is_decode():
return num_tokens // 2
elif forward_mode.is_idle():
assert num_tokens == 0
return 0
else:
raise NotImplementedError
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
overall_sum = sum(arr)
accumulator, split_index = 0, 0
for value in arr[:-1]:
accumulator += value
split_index += 1
if accumulator >= overall_sum // 2:
break
return split_index
def compute_split_token_index(
split_seq_index: int,
forward_mode: "ForwardMode",
extend_seq_lens: Optional[Sequence[int]],
) -> int:
if forward_mode.is_extend():
assert extend_seq_lens is not None
return sum(extend_seq_lens[:split_seq_index])
elif forward_mode.is_decode():
return split_seq_index
elif forward_mode.is_idle():
assert split_seq_index == 0
return 0
else:
raise NotImplementedError
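# --- Worked example (not part of the commit) ------------------------------------
# Standalone copy of the helpers above, showing how an extend batch is split so
# both halves carry roughly the same number of tokens. With extend_lens
# [5, 3, 4, 8] (20 tokens total), the running prefix sum reaches >= 10 after the
# third request, so split_seq_index == 3 and split_token_index == 12.
def _split_array_by_half_sum_demo(arr):
    overall_sum = sum(arr)
    accumulator, split_index = 0, 0
    for value in arr[:-1]:
        accumulator += value
        split_index += 1
        if accumulator >= overall_sum // 2:
            break
    return split_index


extend_lens = [5, 3, 4, 8]
split_seq_index = _split_array_by_half_sum_demo(extend_lens)
split_token_index = sum(extend_lens[:split_seq_index])
assert (split_seq_index, split_token_index) == (3, 12)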
# -------------------------------- Preparation ---------------------------------------
class TboCudaGraphRunnerUtils:
@staticmethod
def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
if that.model_runner.server_args.enable_two_batch_overlap:
tbo_split_seq_index = compute_split_seq_index(
forward_mode=that.capture_forward_mode,
num_tokens=num_tokens,
extend_lens=None,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert (
tbo_split_seq_index is not None
), f"{that.capture_forward_mode=} {num_tokens=}"
else:
tbo_split_seq_index = None
return tbo_split_seq_index
class TboDPAttentionPreparer:
def prepare_all_gather(
self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
):
self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None:
self.local_tbo_split_seq_index = compute_split_seq_index(
forward_mode=local_batch.forward_mode,
num_tokens=local_batch.input_ids.shape[0],
extend_lens=local_batch.extend_lens,
)
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
local_batch.forward_mode.is_extend()
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.low_latency)
)
else:
self.local_tbo_split_seq_index = 0
local_can_run_tbo = True
local_forward_mode = self._compute_local_forward_mode(local_batch)
return local_can_run_tbo, local_forward_mode
def compute_output(self, partial_global_info):
local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist())
forward_modes = partial_global_info[:, 0, 1].tolist()
global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
forward_modes
)
can_run_tbo = (
self.enable_two_batch_overlap
and local_can_run_tbo_aggregated
and forward_mode_agree
)
tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
global_forward_mode = global_forward_mode if can_run_tbo else None
return tbo_split_seq_index, global_forward_mode
@staticmethod
def _compute_local_forward_mode(local_batch):
return (
local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
).value
@staticmethod
def _compute_global_forward_mode(forward_modes):
converted_forward_modes = [
ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
for x in forward_modes
]
forward_mode_agree = TboDPAttentionPreparer._is_all_same(
converted_forward_modes
)
global_forward_mode = (
ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
)
return global_forward_mode, forward_mode_agree
@staticmethod
def _is_all_same(x):
return all(value == x[0] for value in x)
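# --- Layout sketch (not part of the commit) --------------------------------------
# How the preparer above plugs into the scheduler's DP all-gather: each rank now
# contributes 6 int64 values instead of 4, the last two coming from
# prepare_all_gather() (the local can-run-TBO flag and the local forward-mode
# value), and compute_output() consumes columns 4:6 of the gathered tensor.
# All concrete numbers below are made up for illustration.
import torch

dp_size, attn_tp_size = 2, 1
local_info = torch.tensor(
    [
        128,  # num_tokens
        1,    # can_cuda_graph
        128,  # num_tokens_for_logprob
        0,    # is_extend_in_batch
        1,    # local_can_run_tbo        (from prepare_all_gather)
        2,    # local ForwardMode value  (from prepare_all_gather)
    ],
    dtype=torch.int64,
)
# Stand-in for torch.distributed.all_gather_into_tensor across the DP group:
global_info = local_info.repeat(dp_size, attn_tp_size, 1)
assert global_info.shape == (dp_size, attn_tp_size, 6)
tbo_columns = global_info[:, :, 4:6]  # what TboDPAttentionPreparer.compute_output reads
assert tbo_columns.shape == (dp_size, attn_tp_size, 2)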
class TboForwardBatchPreparer:
@classmethod
def prepare(cls, batch: ForwardBatch):
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
if batch.tbo_split_seq_index is None:
return
tbo_split_token_index = compute_split_token_index(
split_seq_index=batch.tbo_split_seq_index,
forward_mode=batch.forward_mode,
extend_seq_lens=batch.extend_seq_lens_cpu,
)
assert isinstance(batch.attn_backend, TboAttnBackend)
attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children
child_a = cls.filter_batch(
batch,
start_token_index=0,
end_token_index=tbo_split_token_index,
start_seq_index=0,
end_seq_index=batch.tbo_split_seq_index,
output_attn_backend=attn_backend_child_a,
)
child_b = cls.filter_batch(
batch,
start_token_index=tbo_split_token_index,
end_token_index=batch.input_ids.shape[0],
start_seq_index=batch.tbo_split_seq_index,
end_seq_index=batch.batch_size,
output_attn_backend=attn_backend_child_b,
)
assert batch.tbo_children is None
batch.tbo_children = [child_a, child_b]
@classmethod
def filter_batch(
cls,
batch: ForwardBatch,
*,
start_token_index: int,
end_token_index: int,
start_seq_index: int,
end_seq_index: int,
output_attn_backend: AttentionBackend,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict
num_tokens = batch.input_ids.shape[0]
num_seqs = batch.batch_size
output_dict = dict()
for key in [
"input_ids",
"positions",
"out_cache_loc",
]:
old_value = getattr(batch, key)
assert (
old_value.shape[0] == num_tokens
), f"{key=} {old_value=} {num_tokens=} {batch=}"
output_dict[key] = old_value[start_token_index:end_token_index]
for key in [
"req_pool_indices",
"seq_lens",
"seq_lens_cpu",
"extend_seq_lens",
"extend_prefix_lens",
"extend_start_loc",
"extend_prefix_lens_cpu",
"extend_seq_lens_cpu",
"extend_logprob_start_lens_cpu",
"lora_paths",
]:
old_value = getattr(batch, key)
if old_value is None:
continue
assert (
len(old_value) == num_seqs
), f"{key=} {old_value=} {num_seqs=} {batch=}"
output_dict[key] = old_value[start_seq_index:end_seq_index]
for key in [
"forward_mode",
"return_logprob",
"req_to_token_pool",
"token_to_kv_pool",
"can_run_dp_cuda_graph",
"global_forward_mode",
"spec_info",
"spec_algorithm",
"capture_hidden_mode",
"padded_static_len",
"mrope_positions", # only used by qwen2-vl, thus not care
]:
output_dict[key] = getattr(batch, key)
assert (
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
== batch.extend_num_tokens
), f"{batch=}"
extend_num_tokens = _compute_extend_num_tokens(
output_dict["input_ids"], output_dict["forward_mode"]
)
# TODO improve, e.g. unify w/ `init_raw`
if global_server_args_dict["moe_dense_tp_size"] == 1:
sum_len = end_token_index - start_token_index
gathered_buffer = torch.zeros(
(sum_len, batch.gathered_buffer.shape[1]),
dtype=batch.gathered_buffer.dtype,
device=batch.gathered_buffer.device,
)
else:
gathered_buffer = None
output_dict.update(
dict(
batch_size=end_seq_index - start_seq_index,
seq_lens_sum=(
output_dict["seq_lens_cpu"].sum()
if "seq_lens_cpu" in output_dict
else None
),
extend_num_tokens=extend_num_tokens,
attn_backend=output_attn_backend,
tbo_split_seq_index=None,
tbo_parent_token_range=(start_token_index, end_token_index),
tbo_children=None,
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
gathered_buffer=gathered_buffer,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
sampling_info=None,
# Only used for logits/logprobs post-processing, so not needed here
temp_scaled_logprobs=False,
temperature=None,
top_p_normalized_logprobs=False,
top_p=None,
mm_inputs=None,
num_token_non_padded=None,
)
)
errors = []
for field in dataclasses.fields(ForwardBatch):
if getattr(batch, field.name) is not None and field.name not in output_dict:
errors.append(
f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
)
if len(errors) > 0:
raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))
return ForwardBatch(**output_dict)
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
if forward_mode.is_extend():
return input_ids.shape[0]
elif forward_mode.is_decode() or forward_mode.is_idle():
return None
raise NotImplementedError
# -------------------------------- Execution ---------------------------------------
def model_forward_maybe_tbo(
layers,
enable_tbo: bool,
positions: torch.Tensor,
forward_batch: ForwardBatch,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
):
inputs = dict(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
zero_allocator=zero_allocator,
)
operations_strategy = OperationsStrategy.init_new_tbo(
layers, forward_batch.global_forward_mode
)
if enable_tbo:
return _model_forward_tbo(inputs, operations_strategy)
else:
return _model_forward_non_tbo(inputs, operations_strategy)
def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
# The attn_tp_size!=1 case is not yet extracted to master
assert get_attention_tp_size() == 1
inputs_arr = _model_forward_tbo_split_inputs(**inputs)
del inputs
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
outputs_arr = execute_overlapped_operations(
inputs_arr=inputs_arr,
operations_arr=[operations_strategy.operations] * 2,
delta_stages=[0, operations_strategy.tbo_delta_stages],
)
return _model_forward_tbo_merge_outputs(*outputs_arr)
def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
outputs = execute_operations(inputs, operations_strategy.operations)
return outputs["hidden_states"], outputs["residual"]
def _model_forward_tbo_split_inputs(
hidden_states: torch.Tensor,
residual: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> List[Dict]:
return [
dict(
**_model_forward_filter_inputs(
hidden_states=hidden_states,
residual=residual,
positions=positions,
output_forward_batch=output_forward_batch,
tbo_subbatch_index=tbo_subbatch_index,
),
zero_allocator=zero_allocator,
)
for tbo_subbatch_index, output_forward_batch in enumerate(
forward_batch.tbo_children
)
]
def _model_forward_filter_inputs(
hidden_states: torch.Tensor,
residual: torch.Tensor,
positions: torch.Tensor,
output_forward_batch: ForwardBatch,
tbo_subbatch_index: int,
) -> Dict:
token_slice = slice(*output_forward_batch.tbo_parent_token_range)
return dict(
hidden_states=hidden_states[token_slice],
residual=None if residual is None else residual[token_slice],
positions=positions[token_slice],
forward_batch=output_forward_batch,
tbo_subbatch_index=tbo_subbatch_index,
)
def _model_forward_tbo_merge_outputs(output_a, output_b):
def _handle_key(name):
value_a = output_a[name]
value_b = output_b[name]
assert (value_a is None) == (value_b is None)
if value_a is None:
return None
return torch.concat([value_a, value_b], dim=0)
return _handle_key("hidden_states"), _handle_key("residual")
# -------------------------------- Utilities and wrappers ---------------------------------------
class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs):
num_inner_dispatchers = (
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
)
self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
def dispatch(self, **kwargs):
return self._execute("dispatch", **kwargs)
def dispatch_a(self, **kwargs):
return self._execute("dispatch_a", **kwargs)
def dispatch_b(self, **kwargs):
return self._execute("dispatch_b", **kwargs)
def combine(self, **kwargs):
return self._execute("combine", **kwargs)
def combine_a(self, **kwargs):
return self._execute("combine_a", **kwargs)
def combine_b(self, **kwargs):
return self._execute("combine_b", **kwargs)
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestTwoBatchOverlap(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--enable-deepep-moe",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_generate_single_prompt(self):
response = requests.post(
self.base_url + "/generate",
# Use an uncommon prompt prefix to minimize the chance of an accidental cache hit
json={
"text": "_ 1+1=2, 1+2=3, 1+3=4, 1+4=",
"sampling_params": {"temperature": 0, "max_new_tokens": 8},
},
)
print(f"{response.json()=}")
self.assertEqual(response.json()["text"], "5, 1+5=6")
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.5)
if __name__ == "__main__":
unittest.main()