Unverified commit 41650b0d authored by Qiaolin Yu, committed by GitHub

feat: support compatibility between MTP and two-batch-overlap (#7225)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 1b951620
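
This change makes two-batch-overlap (TBO) aware of MTP/EAGLE speculation: in target-verify mode every sequence carries draft_token_num tokens, so a batch can no longer be split on the assumption of one token per sequence. The patch threads a token_num_per_seq value through the split logic and slices spec_info for each micro-batch. A minimal standalone sketch of the split arithmetic introduced below (illustration only; the helper names here are hypothetical, the real logic lives in two_batch_overlap.py):

def example_token_num_per_seq(mode, draft_token_num):
    # Mirrors get_token_num_per_seq below: target-verify carries draft_token_num
    # tokens per sequence, plain decode carries 1, idle carries 0.
    return {"target_verify": draft_token_num, "decode": 1, "idle": 0}[mode]

def example_split_indices(num_tokens, token_num_per_seq):
    # Split the batch in half by whole sequences, then convert the boundary back
    # to a token index, as compute_split_seq_index / compute_split_token_index do.
    split_seq_index = (num_tokens // token_num_per_seq) // 2
    return split_seq_index, split_seq_index * token_num_per_seq

# Example: 32 sequences with draft_token_num=4 give a 128-token CUDA graph,
# split into two children of 16 sequences / 64 tokens each.
assert example_split_indices(32 * 4, 4) == (16, 64)
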
......@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs
assert (
capture_num_tokens == bs * token_num_per_seq
), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
num_tokens = bs * token_num_per_seq
tbo_split_seq_index, tbo_split_token_index = (
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
cuda_graph_num_tokens=num_tokens,
spec_info=spec_info,
)
)
num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index
bs_child_left = num_tokens_child_left
bs_child_right = num_tokens_child_right
bs_child_left = tbo_split_seq_index
bs_child_right = bs - bs_child_left
assert (
num_tokens_child_left > 0 and num_tokens_child_right > 0
......@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[EagleVerifyInput],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
assert encoder_lens is None, "encoder_lens is not supported yet"
assert spec_info is None, "spec_info is not supported yet"
if spec_info is not None:
output_spec_info = two_batch_overlap.split_spec_info(
spec_info=spec_info,
start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
start_token_index=(
seq_slice.start * token_num_per_seq
if seq_slice.start is not None
else 0
),
end_token_index=(
seq_slice.stop * token_num_per_seq
if seq_slice.stop is not None
else bs * token_num_per_seq
),
)
else:
output_spec_info = None
ans = dict(
bs=output_bs,
req_pool_indices=req_pool_indices[seq_slice],
......@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
forward_mode=forward_mode,
# ignore
encoder_lens=None,
spec_info=None,
spec_info=output_spec_info,
)
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
assert (
capture_num_tokens == bs * token_num_per_seq
), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
ans.update(
dict(
num_tokens=output_bs,
num_tokens=output_bs * token_num_per_seq,
)
)
elif fn_name == "init_forward_metadata_replay_cuda_graph":
......
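
In the metadata split above, each child is described by a sequence slice; under target-verify (and decode) the matching token range is obtained by scaling the sequence bounds by token_num_per_seq. A small sketch of that mapping, with a hypothetical helper name:

def example_seq_slice_to_token_slice(seq_slice, bs, token_num_per_seq):
    # Mirrors how _init_forward_metadata_cuda_graph_split derives
    # start_token_index / end_token_index from seq_slice.
    start_seq = seq_slice.start if seq_slice.start is not None else 0
    stop_seq = seq_slice.stop if seq_slice.stop is not None else bs
    return slice(start_seq * token_num_per_seq, stop_seq * token_num_per_seq)

# Example: the right child of an 8-sequence target-verify batch with
# draft_token_num=4 covers tokens 16..31.
assert example_seq_slice_to_token_slice(slice(4, None), 8, 4) == slice(16, 32)
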
......@@ -679,6 +679,7 @@ class CudaGraphRunner:
forward_mode=self.capture_forward_mode,
bs=bs,
num_token_non_padded=len(forward_batch.input_ids),
spec_info=forward_batch.spec_info,
)
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
forward_batch.spec_info.custom_mask = self.custom_mask
......
......@@ -352,7 +352,9 @@ class ForwardBatch:
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret)
TboForwardBatchPreparer.prepare(
ret, is_draft_worker=model_runner.is_draft_worker
)
return ret
# Override the positions with spec_info
......@@ -397,7 +399,9 @@ class ForwardBatch:
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)
TboForwardBatchPreparer.prepare(ret)
TboForwardBatchPreparer.prepare(
ret, is_draft_worker=model_runner.is_draft_worker
)
return ret
......
......@@ -1039,7 +1039,7 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.enable_two_batch_overlap:
if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
else:
self.attn_backend = self._get_attention_backend()
......
......@@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_deepseek_blog_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
elif (
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
):
return _compute_moe_deepseek_blog_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
......@@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_qwen3_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
elif (
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
):
return _compute_moe_qwen3_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
......
import dataclasses
import logging
from typing import Dict, List, Optional, Sequence
from dataclasses import replace
from typing import Dict, List, Optional, Sequence, Union
import torch
......@@ -16,6 +17,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
......@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
# -------------------------------- Compute Basic Info ---------------------------------------
def get_token_num_per_seq(
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
if forward_mode.is_target_verify():
return spec_info.draft_token_num
elif forward_mode.is_decode():
return 1
elif forward_mode.is_idle():
return 0
else:
# For extend, we should not use `token_num_per_seq`.
return None
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
def compute_split_seq_index(
forward_mode: "ForwardMode",
num_tokens: int,
extend_lens: Optional[Sequence[int]],
token_num_per_seq: Optional[int],
) -> Optional[int]:
if forward_mode.is_extend():
if forward_mode == ForwardMode.EXTEND:
assert extend_lens is not None
return _split_array_by_half_sum(extend_lens)
elif forward_mode.is_decode():
return num_tokens // 2
elif forward_mode.is_target_verify() or forward_mode.is_decode():
assert token_num_per_seq is not None
return (num_tokens // token_num_per_seq) // 2
elif forward_mode.is_idle():
assert num_tokens == 0
return 0
......@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
return best_index
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
if seq_index == 0:
return 0
offset = 0
max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
for i in range(max_seq_len):
offset += (
spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
) * spec_info.draft_token_num
return offset
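# Worked example (hypothetical numbers): with draft_token_num=4 and
# seq_lens_cpu=[10, 7, 12], sequence i contributes
# (seq_lens_cpu[i] + draft_token_num) * draft_token_num mask entries, so the
# offset of sequence 2 is (10 + 4) * 4 + (7 + 4) * 4 = 100. split_spec_info
# below uses these offsets to slice custom_mask for each TBO child.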
def split_spec_info(
spec_info: Optional[EagleVerifyInput],
start_seq_index: int,
end_seq_index: int,
start_token_index: int,
end_token_index: int,
):
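# Slice the per-token fields (draft_token, positions) by the token range and the
# per-sequence fields (retrive_*, seq_lens_cpu) by the sequence range, so each
# TBO child receives a self-consistent EagleVerifyInput.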
if spec_info is None:
return None
if spec_info.draft_token is not None:
draft_token = spec_info.draft_token[start_token_index:end_token_index]
else:
draft_token = None
if spec_info.custom_mask is not None and spec_info.draft_token is not None:
custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
custom_mask_end = spec_info.custom_mask.shape[0]
else:
custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
if custom_mask_end > custom_mask_start:
custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
else:
custom_mask = spec_info.custom_mask
else:
custom_mask = spec_info.custom_mask
if spec_info.positions is not None:
positions = spec_info.positions[start_token_index:end_token_index]
else:
positions = None
if spec_info.retrive_index is not None:
retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
else:
retrive_index = None
if spec_info.retrive_next_token is not None:
retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
else:
retrive_next_token = None
if spec_info.retrive_next_sibling is not None:
retrive_next_sibling = spec_info.retrive_next_sibling[
start_seq_index:end_seq_index
]
else:
retrive_next_sibling = None
if spec_info.retrive_cum_len is not None:
retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
else:
retrive_cum_len = None
if spec_info.seq_lens_cpu is not None:
seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
else:
seq_lens_cpu = None
if seq_lens_cpu is not None:
seq_lens_sum = seq_lens_cpu.sum()
else:
seq_lens_sum = None
output_spec_info = replace(
spec_info,
custom_mask=custom_mask,
draft_token=draft_token,
positions=positions,
retrive_index=retrive_index,
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
retrive_cum_len=retrive_cum_len,
seq_lens_cpu=seq_lens_cpu,
seq_lens_sum=seq_lens_sum,
)
return output_spec_info
def compute_split_token_index(
split_seq_index: int,
forward_mode: "ForwardMode",
extend_seq_lens: Optional[Sequence[int]],
token_num_per_seq: Optional[int],
) -> int:
if forward_mode.is_extend():
if forward_mode == ForwardMode.EXTEND:
assert extend_seq_lens is not None
return sum(extend_seq_lens[:split_seq_index])
elif forward_mode.is_decode():
return split_seq_index
elif forward_mode.is_target_verify() or forward_mode.is_decode():
assert token_num_per_seq is not None
return split_seq_index * token_num_per_seq
elif forward_mode.is_idle():
assert split_seq_index == 0
return 0
......@@ -83,19 +189,25 @@ def compute_split_token_index(
def compute_split_indices_for_cuda_graph_replay(
forward_mode: ForwardMode,
cuda_graph_num_tokens: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
)
token_num_per_seq = get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
tbo_split_seq_index = compute_split_seq_index(
forward_mode=forward_mode_for_tbo_split,
num_tokens=cuda_graph_num_tokens,
extend_lens=None,
token_num_per_seq=token_num_per_seq,
)
tbo_split_token_index = compute_split_token_index(
split_seq_index=tbo_split_seq_index,
forward_mode=forward_mode_for_tbo_split,
extend_seq_lens=None,
token_num_per_seq=token_num_per_seq,
)
return tbo_split_seq_index, tbo_split_token_index
......@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
return
token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info
)
batch.tbo_split_seq_index = compute_split_seq_index(
forward_mode=batch.forward_mode,
num_tokens=num_tokens,
extend_lens=None,
token_num_per_seq=token_num_per_seq,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
......@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
)
def replay_prepare(
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
self,
forward_mode: ForwardMode,
bs: int,
num_token_non_padded: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
token_num_per_seq = get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
tbo_split_seq_index, tbo_split_token_index = (
compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
# TODO support bs!=num_tokens
cuda_graph_num_tokens=bs,
cuda_graph_num_tokens=bs * token_num_per_seq,
spec_info=spec_info,
)
)
......@@ -154,14 +277,29 @@ class TboDPAttentionPreparer:
self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None:
token_num_per_seq = get_token_num_per_seq(
forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
)
if (
local_batch.forward_mode.is_target_verify()
or local_batch.forward_mode.is_decode()
):
num_tokens = local_batch.batch_size() * token_num_per_seq
else:
num_tokens = local_batch.extend_num_tokens
self.local_tbo_split_seq_index = compute_split_seq_index(
forward_mode=local_batch.forward_mode,
num_tokens=local_batch.input_ids.shape[0],
num_tokens=num_tokens,
extend_lens=local_batch.extend_lens,
token_num_per_seq=token_num_per_seq,
)
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
local_batch.forward_mode.is_extend()
(
local_batch.forward_mode.is_extend()
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.low_latency)
)
......@@ -218,8 +356,8 @@ class TboDPAttentionPreparer:
class TboForwardBatchPreparer:
@classmethod
def prepare(cls, batch: ForwardBatch):
if batch.tbo_split_seq_index is None:
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
if batch.tbo_split_seq_index is None or is_draft_worker:
return
tbo_children_num_token_non_padded = (
......@@ -242,7 +380,9 @@ class TboForwardBatchPreparer:
f"TboForwardBatchPreparer.prepare "
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
f"tbo_split_token_index={tbo_split_token_index} "
f"extend_seq_lens={batch.extend_seq_lens_cpu}"
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
f"bs={batch.batch_size} "
f"forward_mode={batch.forward_mode}"
)
assert isinstance(batch.attn_backend, TboAttnBackend)
......@@ -286,6 +426,9 @@ class TboForwardBatchPreparer:
output_attn_backend: AttentionBackend,
out_num_token_non_padded: torch.Tensor,
):
assert (
end_token_index >= start_token_index
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
num_tokens = batch.input_ids.shape[0]
num_seqs = batch.batch_size
......@@ -317,11 +460,30 @@ class TboForwardBatchPreparer:
old_value = getattr(batch, key)
if old_value is None:
continue
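# For target-verify children the per-sequence extend_* fields are dropped
# (set to None) rather than sliced, as handled by the branch below.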
elif batch.forward_mode.is_target_verify() and (
key == "extend_seq_lens"
or key == "extend_prefix_lens"
or key == "extend_start_loc"
or key == "extend_prefix_lens_cpu"
or key == "extend_seq_lens_cpu"
or key == "extend_logprob_start_lens_cpu"
):
output_dict[key] = None
continue
assert (
len(old_value) == num_seqs
), f"{key=} {old_value=} {num_seqs=} {batch=}"
output_dict[key] = old_value[start_seq_index:end_seq_index]
spec_info = getattr(batch, "spec_info")
output_spec_info = split_spec_info(
spec_info=spec_info,
start_token_index=start_token_index,
end_token_index=end_token_index,
start_seq_index=start_seq_index,
end_seq_index=end_seq_index,
)
output_dict["spec_info"] = output_spec_info
for key in [
"forward_mode",
"return_logprob",
......@@ -329,18 +491,17 @@ class TboForwardBatchPreparer:
"token_to_kv_pool",
"can_run_dp_cuda_graph",
"global_forward_mode",
"spec_info",
"spec_algorithm",
"capture_hidden_mode",
"padded_static_len",
"mrope_positions", # only used by qwen2-vl, thus not care
]:
output_dict[key] = getattr(batch, key)
assert (
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
== batch.extend_num_tokens
), f"{batch=}"
if not batch.forward_mode.is_target_verify():
assert (
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
== batch.extend_num_tokens
), f"{batch=}"
extend_num_tokens = _compute_extend_num_tokens(
output_dict["input_ids"], output_dict["forward_mode"]
)
......@@ -419,18 +580,26 @@ class TboForwardBatchPreparer:
@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info
)
return compute_split_token_index(
split_seq_index=batch.tbo_split_seq_index,
forward_mode=batch.forward_mode,
extend_seq_lens=batch.extend_seq_lens_cpu,
token_num_per_seq=token_num_per_seq,
)
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
if forward_mode.is_extend():
return input_ids.shape[0]
elif forward_mode.is_decode() or forward_mode.is_idle():
if (
forward_mode.is_decode()
or forward_mode.is_idle()
or forward_mode.is_target_verify()
):
return None
elif forward_mode.is_extend():
return input_ids.shape[0]
raise NotImplementedError
......
......@@ -137,5 +137,86 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
self.assertGreater(avg_spec_accept_length, 2.5)
# TODO: enable this test later
# class TestDPAttentionDP2TP2DeepseekV3MTPTBO(CustomTestCase):
# @classmethod
# def setUpClass(cls):
# import os
# # print debug log for tbo
# os.environ["SGLANG_TBO_DEBUG"] = "1"
# cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
# cls.base_url = DEFAULT_URL_FOR_TEST
# other_args = [
# "--trust-remote-code",
# "--disable-radix",
# "--speculative-algorithm",
# "EAGLE",
# "--speculative-num-steps",
# "2",
# "--speculative-eagle-topk",
# "4",
# "--speculative-num-draft-tokens",
# "4",
# "--speculative-draft",
# DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
# "--tp-size",
# "2",
# "--enable-dp-attention",
# "--dp-size",
# "2",
# "--enable-two-batch-overlap",
# "--enable-deepep-moe",
# "--deepep-mode",
# "low_latency",
# "--chunked-prefill-size",
# "256",
# "--cuda-graph-max-bs",
# "32",
# "--max-running-requests",
# "32",
# ]
# if not is_in_amd_ci():
# other_args += ["--mem-frac", "0.7"]
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# other_args=other_args,
# )
# @classmethod
# def tearDownClass(cls):
# kill_process_tree(cls.process.pid)
# def test_gsm8k(self):
# requests.get(self.base_url + "/flush_cache")
# args = SimpleNamespace(
# num_shots=5,
# data_path=None,
# num_questions=200,
# max_new_tokens=512,
# parallel=128,
# host="http://127.0.0.1",
# port=int(self.base_url.split(":")[-1]),
# )
# metrics = run_eval_few_shot_gsm8k(args)
# print(metrics)
# self.assertGreater(metrics["accuracy"], 0.60)
# server_info = requests.get(self.base_url + "/get_server_info")
# avg_spec_accept_length = server_info.json()["internal_states"][0][
# "avg_spec_accept_length"
# ]
# print(
# f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
# f"accuracy={metrics['accuracy']=:.3f}\n"
# f"{avg_spec_accept_length=:.3f}\n"
# )
# self.assertGreater(avg_spec_accept_length, 2.3)
if __name__ == "__main__":
unittest.main()