Unverified commit 31589e17, authored by fzyzcjy, committed by GitHub

Speed up two-batch overlap when there are padding tokens (#6668)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent ae6a5b29
@@ -454,6 +454,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=state.forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,
             ),
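The hunk above threads the batch's non-padded token count into the expert-selection call of DeepseekV2MoE. One way such a count can translate into a speedup is by dropping expert assignments for padding tokens so no MoE dispatch/combine work is done for them. The sketch below is an assumption for illustration only, not the actual sglang routing code; `mask_padded_topk_ids` and `invalid_expert_id` are made-up names.

```python
import torch

def mask_padded_topk_ids(
    topk_ids: torch.Tensor,              # (num_tokens, top_k) expert ids from routing
    num_token_non_padded: torch.Tensor,  # 0-dim int tensor: count of real tokens
    invalid_expert_id: int,              # id that downstream kernels would ignore
) -> torch.Tensor:
    num_tokens = topk_ids.shape[0]
    token_index = torch.arange(num_tokens, device=topk_ids.device)
    is_padding = (token_index >= num_token_non_padded).unsqueeze(-1)
    # Padding tokens keep no expert assignment, so no MoE work is done for them.
    return torch.where(
        is_padding, torch.full_like(topk_ids, invalid_expert_id), topk_ids
    )

# Example: 6 token slots, only the first 4 are real tokens.
ids = torch.randint(0, 8, (6, 2))
masked = mask_padded_topk_ids(ids, torch.tensor(4), invalid_expert_id=-1)
assert (masked[4:] == -1).all()
```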
@@ -110,7 +110,7 @@ def compute_split_indices_for_cuda_graph_replay(

 class TboCudaGraphRunnerPlugin:
     def __init__(self):
-        pass  # TODO add logic here
+        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         if not global_server_args_dict["enable_two_batch_overlap"]:
@@ -124,7 +124,14 @@ class TboCudaGraphRunnerPlugin:
         # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
         assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

-        TboForwardBatchPreparer.prepare(batch)
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
+        )
+        TboForwardBatchPreparer.prepare_raw(
+            batch,
+            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
+        )

     def replay_prepare(
         self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
@@ -132,7 +139,20 @@ class TboCudaGraphRunnerPlugin:
         if not global_server_args_dict["enable_two_batch_overlap"]:
             return

-        pass  # TODO add logic here
+        tbo_split_seq_index, tbo_split_token_index = (
+            compute_split_indices_for_cuda_graph_replay(
+                forward_mode=forward_mode,
+                # TODO support bs!=num_tokens
+                cuda_graph_num_tokens=bs,
+            )
+        )
+
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
+                tbo_split_token_index=tbo_split_token_index,
+                num_token_non_padded=num_token_non_padded,
+            )
+        )


 class TboDPAttentionPreparer:
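In the plugin changes above, `_tbo_children_num_token_non_padded` is allocated once in `__init__` and refreshed in place with `[...] =` both at capture time and in `replay_prepare`; `prepare_raw` exists so this persistent tensor can be handed to the batch preparer instead of a freshly created one. Below is a minimal standalone sketch of that persistent-buffer pattern (illustrative names, not sglang APIs): the in-place write keeps the tensor's storage address stable, which is what a captured CUDA graph needs when it is replayed.

```python
import torch

class ReplayBufferSketch:
    """Illustrative only: a captured CUDA graph replays fixed memory addresses,
    so per-replay values must be written into the same tensor, not rebound."""

    def __init__(self):
        # Allocated once; its storage must stay stable across capture and replay.
        self._children_non_padded = torch.zeros((2,), dtype=torch.int32)

    def refresh(self, value_a: int, value_b: int) -> torch.Tensor:
        # In-place write: same tensor object, new contents.
        self._children_non_padded[...] = torch.tensor(
            [value_a, value_b], dtype=torch.int32
        )
        return self._children_non_padded

buf = ReplayBufferSketch()
ptr_before = buf.refresh(96, 54).data_ptr()
ptr_after = buf.refresh(80, 0).data_ptr()
assert ptr_before == ptr_after  # address unchanged, as a replayed graph requires
```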
@@ -207,17 +227,24 @@ class TboDPAttentionPreparer:

 class TboForwardBatchPreparer:
     @classmethod
     def prepare(cls, batch: ForwardBatch):
-        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
-
         if batch.tbo_split_seq_index is None:
             return

-        tbo_split_token_index = compute_split_token_index(
-            split_seq_index=batch.tbo_split_seq_index,
-            forward_mode=batch.forward_mode,
-            extend_seq_lens=batch.extend_seq_lens_cpu,
-        )
+        tbo_children_num_token_non_padded = (
+            cls.compute_tbo_children_num_token_non_padded(batch)
+        )
+        cls.prepare_raw(
+            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
+        )
+
+    @classmethod
+    def prepare_raw(
+        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
+    ):
+        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
+
+        tbo_split_token_index = cls._compute_split_token_index(batch)

         if _tbo_debug:
             logger.info(
                 f"TboForwardBatchPreparer.prepare "
@@ -229,6 +256,10 @@ class TboForwardBatchPreparer:
         assert isinstance(batch.attn_backend, TboAttnBackend)
         attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

+        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
+            tbo_children_num_token_non_padded
+        )
+
         child_a = cls.filter_batch(
             batch,
             start_token_index=0,
@@ -236,6 +267,7 @@ class TboForwardBatchPreparer:
             start_seq_index=0,
             end_seq_index=batch.tbo_split_seq_index,
             output_attn_backend=attn_backend_child_a,
+            out_num_token_non_padded=out_num_token_non_padded_a,
         )
         child_b = cls.filter_batch(
             batch,
@@ -244,6 +276,7 @@ class TboForwardBatchPreparer:
             start_seq_index=batch.tbo_split_seq_index,
             end_seq_index=batch.batch_size,
             output_attn_backend=attn_backend_child_b,
+            out_num_token_non_padded=out_num_token_non_padded_b,
         )

         assert batch.tbo_children is None
@@ -259,9 +292,8 @@ class TboForwardBatchPreparer:
         start_seq_index: int,
         end_seq_index: int,
         output_attn_backend: AttentionBackend,
+        out_num_token_non_padded: torch.Tensor,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         num_tokens = batch.input_ids.shape[0]
         num_seqs = batch.batch_size
@@ -342,6 +374,7 @@ class TboForwardBatchPreparer:
                 ),
                 extend_num_tokens=extend_num_tokens,
                 attn_backend=output_attn_backend,
+                num_token_non_padded=out_num_token_non_padded,
                 tbo_split_seq_index=None,
                 tbo_parent_token_range=(start_token_index, end_token_index),
                 tbo_children=None,
@@ -357,7 +390,6 @@ class TboForwardBatchPreparer:
                 top_p_normalized_logprobs=False,
                 top_p=None,
                 mm_inputs=None,
-                num_token_non_padded=None,
             )
         )
@@ -372,6 +404,32 @@ class TboForwardBatchPreparer:
         return ForwardBatch(**output_dict)

+    @classmethod
+    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
+        return cls.compute_tbo_children_num_token_non_padded_raw(
+            tbo_split_token_index=cls._compute_split_token_index(batch),
+            num_token_non_padded=len(batch.input_ids),
+        )
+
+    @classmethod
+    def compute_tbo_children_num_token_non_padded_raw(
+        cls, tbo_split_token_index: int, num_token_non_padded: int
+    ):
+        # TODO we may make padding on both sub-batches to make it slightly more balanced
+        value_a = min(tbo_split_token_index, num_token_non_padded)
+        value_b = max(0, num_token_non_padded - tbo_split_token_index)
+        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
+            device=global_server_args_dict["device"], non_blocking=True
+        )
+
+    @classmethod
+    def _compute_split_token_index(cls, batch: ForwardBatch):
+        return compute_split_token_index(
+            split_seq_index=batch.tbo_split_seq_index,
+            forward_mode=batch.forward_mode,
+            extend_seq_lens=batch.extend_seq_lens_cpu,
+        )
+

 def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
     if forward_mode.is_extend():
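To make the arithmetic of `compute_tbo_children_num_token_non_padded_raw` concrete: child A owns token slots `[0, tbo_split_token_index)` and child B owns the rest, so the batch's non-padded token count is clamped per child. A standalone restatement of that split in plain Python, with two worked cases:

```python
def split_non_padded(tbo_split_token_index: int, num_token_non_padded: int):
    # Child A covers token slots [0, split); child B covers [split, num_tokens).
    value_a = min(tbo_split_token_index, num_token_non_padded)
    value_b = max(0, num_token_non_padded - tbo_split_token_index)
    return value_a, value_b

# Split point at 96 token slots:
assert split_non_padded(96, 150) == (96, 54)  # both children hold real tokens
assert split_non_padded(96, 80) == (80, 0)    # child B is entirely padding
```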