Unverified commit 41650b0d authored by Qiaolin Yu, committed by GitHub

feat: support compatibility between MTP and two-batch-overlap (#7225)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 1b951620
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
-            num_tokens = bs
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq
         tbo_split_seq_index, tbo_split_token_index = (
             two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
                 cuda_graph_num_tokens=num_tokens,
+                spec_info=spec_info,
             )
         )
         num_tokens_child_left = tbo_split_token_index
         num_tokens_child_right = num_tokens - tbo_split_token_index
-        bs_child_left = num_tokens_child_left
-        bs_child_right = num_tokens_child_right
+        bs_child_left = tbo_split_seq_index
+        bs_child_right = bs - bs_child_left
         assert (
             num_tokens_child_left > 0 and num_tokens_child_right > 0
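With MTP enabled, a target-verify pass carries `draft_token_num` tokens per sequence, so the capture path above now sizes the graph as `bs * token_num_per_seq` and takes each child's batch size from the sequence split rather than the token split. A minimal standalone sketch of that bookkeeping (not part of the patch; names and numbers are illustrative):

    def child_sizes(bs, token_num_per_seq):
        # Mirror of the capture-path arithmetic: num_tokens must equal
        # bs * token_num_per_seq, sequences are split in half, and the token
        # split is the sequence split scaled by token_num_per_seq.
        num_tokens = bs * token_num_per_seq
        split_seq_index = bs // 2
        split_token_index = split_seq_index * token_num_per_seq
        return dict(
            bs_child_left=split_seq_index,
            bs_child_right=bs - split_seq_index,
            num_tokens_child_left=split_token_index,
            num_tokens_child_right=num_tokens - split_token_index,
        )

    # Example: 6 sequences in target-verify mode with 4 draft tokens each
    # -> left child: 3 seqs / 12 tokens, right child: 3 seqs / 12 tokens.
    print(child_sizes(bs=6, token_num_per_seq=4))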
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-    assert spec_info is None, "spec_info is not supported yet"
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )
+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=None,
+        spec_info=output_spec_info,
     )
     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
         ans.update(
             dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
             )
         )
     elif fn_name == "init_forward_metadata_replay_cuda_graph":
......
@@ -679,6 +679,7 @@ class CudaGraphRunner:
             forward_mode=self.capture_forward_mode,
             bs=bs,
             num_token_non_padded=len(forward_batch.input_ids),
+            spec_info=forward_batch.spec_info,
         )
         if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
             forward_batch.spec_info.custom_mask = self.custom_mask
......
@@ -352,7 +352,9 @@ class ForwardBatch:
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
-            TboForwardBatchPreparer.prepare(ret)
+            TboForwardBatchPreparer.prepare(
+                ret, is_draft_worker=model_runner.is_draft_worker
+            )
             return ret

         # Override the positions with spec_info
@@ -397,7 +399,9 @@ class ForwardBatch:
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)

-        TboForwardBatchPreparer.prepare(ret)
+        TboForwardBatchPreparer.prepare(
+            ret, is_draft_worker=model_runner.is_draft_worker
+        )

         return ret
......
@@ -1039,7 +1039,7 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap:
+        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
......
@@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
     assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
     if forward_mode == ForwardMode.EXTEND:
         return _compute_moe_deepseek_blog_prefill(layer)
-    elif forward_mode == ForwardMode.DECODE:
+    elif (
+        forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
+    ):
         return _compute_moe_deepseek_blog_decode(layer)
     else:
         raise NotImplementedError(f"Unsupported {forward_mode=}")
@@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
     assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
     if forward_mode == ForwardMode.EXTEND:
         return _compute_moe_qwen3_prefill(layer)
-    elif forward_mode == ForwardMode.DECODE:
+    elif (
+        forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
+    ):
         return _compute_moe_qwen3_decode(layer)
     else:
         raise NotImplementedError(f"Unsupported {forward_mode=}")
......
 import dataclasses
 import logging
-from typing import Dict, List, Optional, Sequence
+from dataclasses import replace
+from typing import Dict, List, Optional, Sequence, Union

 import torch
@@ -16,6 +17,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var

 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
 # -------------------------------- Compute Basic Info ---------------------------------------

+def get_token_num_per_seq(
+    forward_mode: ForwardMode,
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+):
+    if forward_mode.is_target_verify():
+        return spec_info.draft_token_num
+    elif forward_mode.is_decode():
+        return 1
+    elif forward_mode.is_idle():
+        return 0
+    else:
+        # For extend, we should not use `token_num_per_seq`.
+        return None

 # TODO: may smartly disable TBO when batch size is too small b/c it will slow down
 def compute_split_seq_index(
     forward_mode: "ForwardMode",
     num_tokens: int,
     extend_lens: Optional[Sequence[int]],
+    token_num_per_seq: Optional[int],
 ) -> Optional[int]:
-    if forward_mode.is_extend():
+    if forward_mode == ForwardMode.EXTEND:
         assert extend_lens is not None
         return _split_array_by_half_sum(extend_lens)
-    elif forward_mode.is_decode():
-        return num_tokens // 2
+    elif forward_mode.is_target_verify() or forward_mode.is_decode():
+        assert token_num_per_seq is not None
+        return (num_tokens // token_num_per_seq) // 2
     elif forward_mode.is_idle():
         assert num_tokens == 0
         return 0
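The new `get_token_num_per_seq` helper is what lets the split functions above treat decode and target-verify uniformly: both modes carry a fixed token count per sequence, so the sequence split point is just the token split divided by that count. A rough standalone mirror of the dispatch, using hypothetical values, for intuition (not part of the patch):

    # Illustrative only; the real helpers take ForwardMode and
    # EagleDraftInput/EagleVerifyInput objects.
    def token_num_per_seq(mode, draft_token_num=None):
        if mode == "target_verify":
            return draft_token_num  # each request re-verifies its whole draft tree
        if mode == "decode":
            return 1  # one new token per request
        if mode == "idle":
            return 0
        return None  # extend: the split is balanced over extend_lens instead

    assert token_num_per_seq("decode") == 1
    assert token_num_per_seq("target_verify", draft_token_num=4) == 4
    # A target-verify batch of 8 requests with 4 draft tokens each carries
    # 32 tokens; the sequence split point is (32 // 4) // 2 == 4.
    assert (32 // token_num_per_seq("target_verify", draft_token_num=4)) // 2 == 4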
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     return best_index

+def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
+    if seq_index == 0:
+        return 0
+    offset = 0
+    max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
+    for i in range(max_seq_len):
+        offset += (
+            spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
+        ) * spec_info.draft_token_num
+    return offset

+def split_spec_info(
+    spec_info: Optional[EagleVerifyInput],
+    start_seq_index: int,
+    end_seq_index: int,
+    start_token_index: int,
+    end_token_index: int,
+):
+    if spec_info is None:
+        return None
+    if spec_info.draft_token is not None:
+        draft_token = spec_info.draft_token[start_token_index:end_token_index]
+    else:
+        draft_token = None
+    if spec_info.custom_mask is not None and spec_info.draft_token is not None:
+        custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
+        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
+            custom_mask_end = spec_info.custom_mask.shape[0]
+        else:
+            custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
+        if custom_mask_end > custom_mask_start:
+            custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
+        else:
+            custom_mask = spec_info.custom_mask
+    else:
+        custom_mask = spec_info.custom_mask
+    if spec_info.positions is not None:
+        positions = spec_info.positions[start_token_index:end_token_index]
+    else:
+        positions = None
+    if spec_info.retrive_index is not None:
+        retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
+    else:
+        retrive_index = None
+    if spec_info.retrive_next_token is not None:
+        retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
+    else:
+        retrive_next_token = None
+    if spec_info.retrive_next_sibling is not None:
+        retrive_next_sibling = spec_info.retrive_next_sibling[
+            start_seq_index:end_seq_index
+        ]
+    else:
+        retrive_next_sibling = None
+    if spec_info.retrive_cum_len is not None:
+        retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
+    else:
+        retrive_cum_len = None
+    if spec_info.seq_lens_cpu is not None:
+        seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
+    else:
+        seq_lens_cpu = None
+    if seq_lens_cpu is not None:
+        seq_lens_sum = seq_lens_cpu.sum()
+    else:
+        seq_lens_sum = None
+    output_spec_info = replace(
+        spec_info,
+        custom_mask=custom_mask,
+        draft_token=draft_token,
+        positions=positions,
+        retrive_index=retrive_index,
+        retrive_next_token=retrive_next_token,
+        retrive_next_sibling=retrive_next_sibling,
+        retrive_cum_len=retrive_cum_len,
+        seq_lens_cpu=seq_lens_cpu,
+        seq_lens_sum=seq_lens_sum,
+    )
+    return output_spec_info

 def compute_split_token_index(
     split_seq_index: int,
     forward_mode: "ForwardMode",
     extend_seq_lens: Optional[Sequence[int]],
+    token_num_per_seq: Optional[int],
 ) -> int:
-    if forward_mode.is_extend():
+    if forward_mode == ForwardMode.EXTEND:
         assert extend_seq_lens is not None
         return sum(extend_seq_lens[:split_seq_index])
-    elif forward_mode.is_decode():
-        return split_seq_index
+    elif forward_mode.is_target_verify() or forward_mode.is_decode():
+        assert token_num_per_seq is not None
+        return split_seq_index * token_num_per_seq
     elif forward_mode.is_idle():
         assert split_seq_index == 0
         return 0
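`split_spec_info` slices token-indexed fields (`draft_token`, `positions`) by token range, sequence-indexed fields (the `retrive_*` tensors and `seq_lens_cpu`) by sequence range, and the flattened `custom_mask` by the per-request block offsets that `_compute_mask_offset` accumulates. A small worked example of that offset arithmetic, assuming hypothetical lengths (not part of the patch):

    import torch

    # Illustrative mirror of the mask bookkeeping (not the library API): in
    # target-verify mode, request i owns a contiguous block of
    # (seq_lens_cpu[i] + draft_token_num) * draft_token_num mask entries, so a
    # child that starts at sequence `seq_index` begins at the sum of the
    # preceding block sizes.
    def mask_offset(seq_index, seq_lens_cpu, draft_token_num):
        return sum(
            (int(seq_lens_cpu[i]) + draft_token_num) * draft_token_num
            for i in range(seq_index)
        )

    seq_lens_cpu = torch.tensor([7, 5, 9])  # hypothetical per-request KV lengths
    draft_token_num = 4
    start = mask_offset(1, seq_lens_cpu, draft_token_num)  # (7 + 4) * 4 = 44
    end = mask_offset(3, seq_lens_cpu, draft_token_num)    # 44 + 36 + 52 = 132
    print(start, end)  # the right child would slice custom_mask[44:132]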
@@ -83,19 +189,25 @@ def compute_split_token_index(
 def compute_split_indices_for_cuda_graph_replay(
     forward_mode: ForwardMode,
     cuda_graph_num_tokens: int,
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
 ):
     forward_mode_for_tbo_split = (
         forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
     )
+    token_num_per_seq = get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     tbo_split_seq_index = compute_split_seq_index(
         forward_mode=forward_mode_for_tbo_split,
         num_tokens=cuda_graph_num_tokens,
         extend_lens=None,
+        token_num_per_seq=token_num_per_seq,
     )
     tbo_split_token_index = compute_split_token_index(
         split_seq_index=tbo_split_seq_index,
         forward_mode=forward_mode_for_tbo_split,
         extend_seq_lens=None,
+        token_num_per_seq=token_num_per_seq,
     )
     return tbo_split_seq_index, tbo_split_token_index
@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         if not global_server_args_dict["enable_two_batch_overlap"]:
             return
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=batch.forward_mode, spec_info=batch.spec_info
+        )
         batch.tbo_split_seq_index = compute_split_seq_index(
             forward_mode=batch.forward_mode,
             num_tokens=num_tokens,
             extend_lens=None,
+            token_num_per_seq=token_num_per_seq,
         )
         # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
         assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
         )

     def replay_prepare(
-        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
+        self,
+        forward_mode: ForwardMode,
+        bs: int,
+        num_token_non_padded: int,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         tbo_split_seq_index, tbo_split_token_index = (
             compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
-                # TODO support bs!=num_tokens
-                cuda_graph_num_tokens=bs,
+                cuda_graph_num_tokens=bs * token_num_per_seq,
+                spec_info=spec_info,
             )
         )
@@ -154,14 +277,29 @@ class TboDPAttentionPreparer:
         self.enable_two_batch_overlap = enable_two_batch_overlap

         if local_batch is not None:
+            token_num_per_seq = get_token_num_per_seq(
+                forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
+            )
+            if (
+                local_batch.forward_mode.is_target_verify()
+                or local_batch.forward_mode.is_decode()
+            ):
+                num_tokens = local_batch.batch_size() * token_num_per_seq
+            else:
+                num_tokens = local_batch.extend_num_tokens
             self.local_tbo_split_seq_index = compute_split_seq_index(
                 forward_mode=local_batch.forward_mode,
-                num_tokens=local_batch.input_ids.shape[0],
+                num_tokens=num_tokens,
                 extend_lens=local_batch.extend_lens,
+                token_num_per_seq=token_num_per_seq,
             )
             resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
             local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
+                (
                     local_batch.forward_mode.is_extend()
+                    and not local_batch.forward_mode.is_target_verify()
+                )
                 and enable_deepep_moe
                 and (resolved_deepep_mode == DeepEPMode.low_latency)
             )
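For the DP-attention preparer above, the token count that picks the TBO split is no longer read off `input_ids`: decode and target-verify batches use `batch_size * token_num_per_seq`, while extend batches keep `extend_num_tokens`. A simplified sketch of that selection, using a hypothetical stand-in batch object (not part of the patch):

    # Hypothetical stand-in for the scheduler-side batch; field names are
    # illustrative only.
    class _Batch:
        def __init__(self, mode, num_seqs, extend_num_tokens=None, draft_token_num=None):
            self.mode = mode
            self.num_seqs = num_seqs
            self.extend_num_tokens = extend_num_tokens
            self.draft_token_num = draft_token_num

    def tbo_num_tokens(batch):
        # Decode / target-verify batches have a fixed token count per sequence;
        # extend batches use the scheduler's extend_num_tokens instead.
        if batch.mode in ("decode", "target_verify"):
            per_seq = 1 if batch.mode == "decode" else batch.draft_token_num
            return batch.num_seqs * per_seq
        return batch.extend_num_tokens

    print(tbo_num_tokens(_Batch("target_verify", num_seqs=8, draft_token_num=4)))  # 32
    print(tbo_num_tokens(_Batch("extend", num_seqs=2, extend_num_tokens=300)))     # 300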
@@ -218,8 +356,8 @@ class TboDPAttentionPreparer:
 class TboForwardBatchPreparer:
     @classmethod
-    def prepare(cls, batch: ForwardBatch):
-        if batch.tbo_split_seq_index is None:
+    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
+        if batch.tbo_split_seq_index is None or is_draft_worker:
             return

         tbo_children_num_token_non_padded = (
@@ -242,7 +380,9 @@ class TboForwardBatchPreparer:
                 f"TboForwardBatchPreparer.prepare "
                 f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                 f"tbo_split_token_index={tbo_split_token_index} "
-                f"extend_seq_lens={batch.extend_seq_lens_cpu}"
+                f"extend_seq_lens={batch.extend_seq_lens_cpu} "
+                f"bs={batch.batch_size} "
+                f"forward_mode={batch.forward_mode}"
             )

         assert isinstance(batch.attn_backend, TboAttnBackend)
@@ -286,6 +426,9 @@ class TboForwardBatchPreparer:
         output_attn_backend: AttentionBackend,
         out_num_token_non_padded: torch.Tensor,
     ):
+        assert (
+            end_token_index >= start_token_index
+        ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
         num_tokens = batch.input_ids.shape[0]
         num_seqs = batch.batch_size
@@ -317,11 +460,30 @@ class TboForwardBatchPreparer:
             old_value = getattr(batch, key)
             if old_value is None:
                 continue
+            elif batch.forward_mode.is_target_verify() and (
+                key == "extend_seq_lens"
+                or key == "extend_prefix_lens"
+                or key == "extend_start_loc"
+                or key == "extend_prefix_lens_cpu"
+                or key == "extend_seq_lens_cpu"
+                or key == "extend_logprob_start_lens_cpu"
+            ):
+                output_dict[key] = None
+                continue
             assert (
                 len(old_value) == num_seqs
             ), f"{key=} {old_value=} {num_seqs=} {batch=}"
             output_dict[key] = old_value[start_seq_index:end_seq_index]
+        spec_info = getattr(batch, "spec_info")
+        output_spec_info = split_spec_info(
+            spec_info=spec_info,
+            start_token_index=start_token_index,
+            end_token_index=end_token_index,
+            start_seq_index=start_seq_index,
+            end_seq_index=end_seq_index,
+        )
+        output_dict["spec_info"] = output_spec_info
         for key in [
             "forward_mode",
             "return_logprob",
@@ -329,14 +491,13 @@ class TboForwardBatchPreparer:
             "token_to_kv_pool",
             "can_run_dp_cuda_graph",
             "global_forward_mode",
-            "spec_info",
             "spec_algorithm",
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
+        if not batch.forward_mode.is_target_verify():
             assert (
                 _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
                 == batch.extend_num_tokens
@@ -419,18 +580,26 @@ class TboForwardBatchPreparer:
     @classmethod
     def _compute_split_token_index(cls, batch: ForwardBatch):
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=batch.forward_mode, spec_info=batch.spec_info
+        )
         return compute_split_token_index(
             split_seq_index=batch.tbo_split_seq_index,
             forward_mode=batch.forward_mode,
             extend_seq_lens=batch.extend_seq_lens_cpu,
+            token_num_per_seq=token_num_per_seq,
         )

 def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
-    if forward_mode.is_extend():
-        return input_ids.shape[0]
-    elif forward_mode.is_decode() or forward_mode.is_idle():
+    if (
+        forward_mode.is_decode()
+        or forward_mode.is_idle()
+        or forward_mode.is_target_verify()
+    ):
         return None
+    elif forward_mode.is_extend():
+        return input_ids.shape[0]
     raise NotImplementedError
......
@@ -137,5 +137,86 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
         self.assertGreater(avg_spec_accept_length, 2.5)

+# TODO: enable this test later
+# class TestDPAttentionDP2TP2DeepseekV3MTPTBO(CustomTestCase):
+#     @classmethod
+#     def setUpClass(cls):
+#         import os
+#         # print debug log for tbo
+#         os.environ["SGLANG_TBO_DEBUG"] = "1"
+#         cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+#         cls.base_url = DEFAULT_URL_FOR_TEST
+#         other_args = [
+#             "--trust-remote-code",
+#             "--disable-radix",
+#             "--speculative-algorithm",
+#             "EAGLE",
+#             "--speculative-num-steps",
+#             "2",
+#             "--speculative-eagle-topk",
+#             "4",
+#             "--speculative-num-draft-tokens",
+#             "4",
+#             "--speculative-draft",
+#             DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+#             "--tp-size",
+#             "2",
+#             "--enable-dp-attention",
+#             "--dp-size",
+#             "2",
+#             "--enable-two-batch-overlap",
+#             "--enable-deepep-moe",
+#             "--deepep-mode",
+#             "low_latency",
+#             "--chunked-prefill-size",
+#             "256",
+#             "--cuda-graph-max-bs",
+#             "32",
+#             "--max-running-requests",
+#             "32",
+#         ]
+#         if not is_in_amd_ci():
+#             other_args += ["--mem-frac", "0.7"]
+#         cls.process = popen_launch_server(
+#             cls.model,
+#             cls.base_url,
+#             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+#             other_args=other_args,
+#         )
+#     @classmethod
+#     def tearDownClass(cls):
+#         kill_process_tree(cls.process.pid)
+#     def test_gsm8k(self):
+#         requests.get(self.base_url + "/flush_cache")
+#         args = SimpleNamespace(
+#             num_shots=5,
+#             data_path=None,
+#             num_questions=200,
+#             max_new_tokens=512,
+#             parallel=128,
+#             host="http://127.0.0.1",
+#             port=int(self.base_url.split(":")[-1]),
+#         )
+#         metrics = run_eval_few_shot_gsm8k(args)
+#         print(metrics)
+#         self.assertGreater(metrics["accuracy"], 0.60)
+#         server_info = requests.get(self.base_url + "/get_server_info")
+#         avg_spec_accept_length = server_info.json()["internal_states"][0][
+#             "avg_spec_accept_length"
+#         ]
+#         print(
+#             f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
+#             f"accuracy={metrics['accuracy']=:.3f}\n"
+#             f"{avg_spec_accept_length=:.3f}\n"
+#         )
+#         self.assertGreater(avg_spec_accept_length, 2.3)

 if __name__ == "__main__":
     unittest.main()