Unverified Commit ae6a5b29, authored by fzyzcjy, committed by GitHub

Minor refactor two-batch overlap (#6682)

parent 4839999b
@@ -119,24 +119,15 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
-        from sglang.srt.model_executor.forward_batch_info import ForwardMode
-
         if fn_name == "init_forward_metadata_capture_cuda_graph":
             assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
         num_tokens = bs

-        forward_mode_for_tbo_split = (
-            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
-        )
-        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
-            forward_mode=forward_mode_for_tbo_split,
-            num_tokens=num_tokens,
-            extend_lens=None,
-        )
-        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
-            split_seq_index=tbo_split_seq_index,
-            forward_mode=forward_mode_for_tbo_split,
-            extend_seq_lens=None,
+        tbo_split_seq_index, tbo_split_token_index = (
+            two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
+                forward_mode=forward_mode,
+                cuda_graph_num_tokens=num_tokens,
+            )
         )

         num_tokens_child_left = tbo_split_token_index
...
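
The inline IDLE-to-DECODE fallback and the two split computations in TboAttnBackend are folded into a single helper. Below is a minimal usage sketch of the new entry point, assuming sglang is importable and the helper lives in sglang.srt.two_batch_overlap as added further down in this diff; the num_tokens value is purely illustrative.

# Sketch only: calling the consolidated helper the way the replay path now does.
from sglang.srt import two_batch_overlap
from sglang.srt.model_executor.forward_batch_info import ForwardMode

num_tokens = 8  # capture currently assumes num_tokens == bs
tbo_split_seq_index, tbo_split_token_index = (
    two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
        forward_mode=ForwardMode.IDLE,  # mapped to DECODE inside the helper
        cuda_graph_num_tokens=num_tokens,
    )
)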
@@ -40,7 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import (
-    TboCudaGraphRunnerUtils,
+    TboCudaGraphRunnerPlugin,
     TboForwardBatchPreparer,
 )
 from sglang.srt.utils import (
@@ -256,6 +256,7 @@ class CudaGraphRunner:
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
+        self.tbo_plugin = TboCudaGraphRunnerPlugin()

         # pipeline parallelism
         if self.pp_size > 1:
@@ -481,12 +482,9 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
-            tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
-                self, num_tokens
-            ),
             global_forward_mode=self.capture_forward_mode,
         )
-        TboForwardBatchPreparer.prepare(forward_batch)
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

         if lora_paths is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
@@ -581,7 +579,13 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        self.num_token_non_padded[...] = len(forward_batch.input_ids)
+        num_token_non_padded = len(forward_batch.input_ids)
+        self.num_token_non_padded[...] = num_token_non_padded
+        self.tbo_plugin.replay_prepare(
+            forward_mode=forward_batch.forward_mode,
+            bs=bs,
+            num_token_non_padded=num_token_non_padded,
+        )
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
...
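
In the CUDA graph runner, the per-call-site TBO logic is replaced by a single plugin owned by the runner, with one hook on the capture path and one on the replay path. The following is a reduced lifecycle sketch, assuming the import path shown in the diff; the class name and the trimmed-down attributes are illustrative, not the real CudaGraphRunner.

# Not the actual CudaGraphRunner: a reduced sketch of where the plugin hooks in.
import torch
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin

class RunnerSketch:
    def __init__(self):
        self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
        self.tbo_plugin = TboCudaGraphRunnerPlugin()  # owned for the runner's lifetime

    def capture_one_batch_size(self, forward_batch, num_tokens: int):
        # Replaces the old TboCudaGraphRunnerUtils.compute_tbo_split_seq_index call
        # and the explicit TboForwardBatchPreparer.prepare(forward_batch).
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

    def replay_prepare(self, forward_batch, bs: int):
        num_token_non_padded = len(forward_batch.input_ids)
        self.num_token_non_padded[...] = num_token_non_padded
        # New replay-path hook; its body is still a placeholder in this commit.
        self.tbo_plugin.replay_prepare(
            forward_mode=forward_batch.forward_mode,
            bs=bs,
            num_token_non_padded=num_token_non_padded,
        )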
@@ -85,25 +85,54 @@ def compute_split_token_index(
     raise NotImplementedError


+def compute_split_indices_for_cuda_graph_replay(
+    forward_mode: ForwardMode,
+    cuda_graph_num_tokens: int,
+):
+    forward_mode_for_tbo_split = (
+        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
+    )
+    tbo_split_seq_index = compute_split_seq_index(
+        forward_mode=forward_mode_for_tbo_split,
+        num_tokens=cuda_graph_num_tokens,
+        extend_lens=None,
+    )
+    tbo_split_token_index = compute_split_token_index(
+        split_seq_index=tbo_split_seq_index,
+        forward_mode=forward_mode_for_tbo_split,
+        extend_seq_lens=None,
+    )
+    return tbo_split_seq_index, tbo_split_token_index
+
+
 # -------------------------------- Preparation ---------------------------------------


-class TboCudaGraphRunnerUtils:
-    @staticmethod
-    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
-        if that.model_runner.server_args.enable_two_batch_overlap:
-            tbo_split_seq_index = compute_split_seq_index(
-                forward_mode=that.capture_forward_mode,
-                num_tokens=num_tokens,
-                extend_lens=None,
-            )
-            # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
-            assert (
-                tbo_split_seq_index is not None
-            ), f"{that.capture_forward_mode=} {num_tokens=}"
-        else:
-            tbo_split_seq_index = None
-        return tbo_split_seq_index
+class TboCudaGraphRunnerPlugin:
+    def __init__(self):
+        pass  # TODO add logic here
+
+    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+        batch.tbo_split_seq_index = compute_split_seq_index(
+            forward_mode=batch.forward_mode,
+            num_tokens=num_tokens,
+            extend_lens=None,
+        )
+        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
+        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
+        TboForwardBatchPreparer.prepare(batch)
+
+    def replay_prepare(
+        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
+    ):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+        pass  # TODO add logic here


 class TboDPAttentionPreparer:
...
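
Since compute_split_indices_for_cuda_graph_replay is simply the old two-step computation behind one call (with IDLE mapped to DECODE), its result can be checked against the individual helpers it wraps. A hedged sketch for the decode case follows; the graph size is an illustrative value.

# Sketch: the consolidated helper should agree with the two-step logic it wraps.
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import (
    compute_split_indices_for_cuda_graph_replay,
    compute_split_seq_index,
    compute_split_token_index,
)

num_tokens = 16  # illustrative CUDA graph size
seq_idx, tok_idx = compute_split_indices_for_cuda_graph_replay(
    forward_mode=ForwardMode.DECODE,
    cuda_graph_num_tokens=num_tokens,
)
assert seq_idx == compute_split_seq_index(
    forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
)
assert tok_idx == compute_split_token_index(
    split_seq_index=seq_idx, forward_mode=ForwardMode.DECODE, extend_seq_lens=None
)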