Unverified Commit ca47e24f authored by HouseWest's avatar HouseWest Committed by GitHub
Browse files

[Feature] improve TBO: two chunk overlap (#8144)

parent d26ca84f
......@@ -262,6 +262,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
......
......@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_radix_cache",
"enable_dp_attention",
"enable_two_batch_overlap",
"tbo_token_distribution_threshold",
"enable_dp_lm_head",
"moe_a2a_backend",
"deepep_mode",
......
......@@ -420,16 +420,12 @@ class ForwardBatch:
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
if support_triton(model_runner.server_args.attention_backend):
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,
ret.extend_seq_lens,
ret.extend_num_tokens,
)
else:
positions, ret.extend_start_loc = compute_position_torch(
ret.extend_prefix_lens, ret.extend_seq_lens
)
positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend,
ret.extend_prefix_lens,
ret.extend_seq_lens,
ret.extend_num_tokens,
)
if ret.positions is None:
ret.positions = positions
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
......@@ -882,6 +878,25 @@ class PPProxyTensors:
return f"PPProxyTensors(tensors={self.tensors})"
def compute_position(
attn_backend: str,
extend_prefix_lens: torch.Tensor,
extend_seq_lens: torch.Tensor,
extend_seq_lens_sum: int,
):
if support_triton(attn_backend):
positions, extend_start_loc = compute_position_triton(
extend_prefix_lens,
extend_seq_lens,
extend_seq_lens_sum,
)
else:
positions, extend_start_loc = compute_position_torch(
extend_prefix_lens, extend_seq_lens
)
return positions, extend_start_loc
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
):
......
......@@ -229,6 +229,7 @@ class ServerArgs:
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
tbo_token_distribution_threshold: float = 0.48
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
torchao_config: str = ""
......@@ -1689,6 +1690,12 @@ class ServerArgs:
action="store_true",
help="Enabling two micro batches to overlap.",
)
parser.add_argument(
"--tbo-token-distribution-threshold",
type=float,
default=ServerArgs.tbo_token_distribution_threshold,
help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
......
from __future__ import annotations
import copy
import dataclasses
import logging
from dataclasses import replace
......@@ -17,7 +18,11 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
compute_position,
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
......@@ -58,7 +63,7 @@ def compute_split_seq_index(
) -> Optional[int]:
if forward_mode == ForwardMode.EXTEND:
assert extend_lens is not None
return _split_array_by_half_sum(extend_lens)
return _split_extend_seqs(extend_lens)
elif forward_mode.is_target_verify() or forward_mode.is_decode():
assert token_num_per_seq is not None
return (num_tokens // token_num_per_seq) // 2
......@@ -69,7 +74,43 @@ def compute_split_seq_index(
raise NotImplementedError()
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
if extend_lens is None:
return False
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
left_sum = sum(extend_lens[:vanilla_split_seq_index])
overall_sum = sum(extend_lens)
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
assert threshold <= 0.5, f"{threshold=}"
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
1 - threshold
)
def _split_extend_seqs(arr: Sequence[int]) -> int:
if _is_two_chunk_split_enabled(arr):
return _split_array_by_cum_less_than_half(arr)
return _split_array_by_balanced_sum(arr)
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
left_sum = 0
overall_sum = sum(arr)
half_sum = overall_sum // 2
chosen_index = 0
for i in range(len(arr)):
left_sum += arr[i]
if left_sum > half_sum:
chosen_index = i
break
return chosen_index
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
overall_sum = sum(arr)
left_sum = 0
min_diff = float("inf")
......@@ -88,6 +129,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
return best_index
def _update_device_and_sum_field_from_cpu_field(
batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
):
cpu_value = getattr(batch, cpu_field, None)
old_device_value = getattr(batch, device_field, None)
if (
cpu_value is None
or old_device_value is None
or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
):
return
new_device_value = (
cpu_value
if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=global_server_args_dict["device"], non_blocking=True)
setattr(batch, device_field, new_device_value)
if sum_field is not None:
sum_value = (
cpu_value.sum().item()
if isinstance(cpu_value, torch.Tensor)
else sum(cpu_value)
)
setattr(batch, sum_field, sum_value)
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
if seq_index == 0:
return 0
......@@ -181,6 +250,8 @@ def compute_split_token_index(
) -> int:
if forward_mode == ForwardMode.EXTEND:
assert extend_seq_lens is not None
if _is_two_chunk_split_enabled(extend_seq_lens):
return sum(extend_seq_lens) // 2
return sum(extend_seq_lens[:split_seq_index])
elif forward_mode.is_target_verify() or forward_mode.is_decode():
assert token_num_per_seq is not None
......@@ -388,9 +459,15 @@ class TboForwardBatchPreparer:
tbo_split_token_index = cls._compute_split_token_index(batch)
is_enable_two_chunk = (
batch.forward_mode == ForwardMode.EXTEND
and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
)
if _tbo_debug:
logger.info(
f"TboForwardBatchPreparer.prepare "
f"is_enable_two_chunk={is_enable_two_chunk} "
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
f"tbo_split_token_index={tbo_split_token_index} "
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
......@@ -410,7 +487,11 @@ class TboForwardBatchPreparer:
start_token_index=0,
end_token_index=tbo_split_token_index,
start_seq_index=0,
end_seq_index=batch.tbo_split_seq_index,
end_seq_index=(
batch.tbo_split_seq_index + 1
if is_enable_two_chunk
else batch.tbo_split_seq_index
),
output_attn_backend=attn_backend_child_a,
out_num_token_non_padded=out_num_token_non_padded_a,
)
......@@ -424,9 +505,79 @@ class TboForwardBatchPreparer:
out_num_token_non_padded=out_num_token_non_padded_b,
)
if is_enable_two_chunk:
cls.derive_fields_related_to_seq_len_for_two_chunk(
batch,
child_a=child_a,
child_b=child_b,
tbo_split_seq_index=batch.tbo_split_seq_index,
)
assert batch.tbo_children is None
batch.tbo_children = [child_a, child_b]
@classmethod
def derive_fields_related_to_seq_len_for_two_chunk(
cls,
batch: ForwardBatch,
*,
child_a: ForwardBatch,
child_b: ForwardBatch,
tbo_split_seq_index: int,
):
extend_seq_lens_cpu = batch.extend_seq_lens_cpu
overall_seq_lens_sum = sum(extend_seq_lens_cpu)
half_seq_lens_sum = overall_seq_lens_sum // 2
left_last_seq_token_num = half_seq_lens_sum - sum(
extend_seq_lens_cpu[:tbo_split_seq_index]
)
right_first_seq_token_num = (
extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
)
# making deepcopy to be extra safe
child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
for child in [child_a, child_b]:
_update_device_and_sum_field_from_cpu_field(
batch=child,
cpu_field="extend_seq_lens_cpu",
device_field="extend_seq_lens",
sum_field="extend_num_tokens",
)
assert (
child_a.extend_num_tokens == half_seq_lens_sum
), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
child_a.seq_lens_cpu[-1] = (
child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
)
_update_device_and_sum_field_from_cpu_field(
batch=child_a,
cpu_field="seq_lens_cpu",
device_field="seq_lens",
sum_field="seq_lens_sum",
)
child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
_update_device_and_sum_field_from_cpu_field(
batch=child_b,
cpu_field="extend_prefix_lens_cpu",
device_field="extend_prefix_lens",
sum_field=None,
)
_, child_b.extend_start_loc = compute_position(
global_server_args_dict["attention_backend"],
child_b.extend_prefix_lens,
child_b.extend_seq_lens,
child_b.extend_num_tokens,
)
@classmethod
def filter_batch(
cls,
......
......@@ -5,7 +5,10 @@ from types import SimpleNamespace
import requests
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import compute_split_seq_index
from sglang.srt.two_batch_overlap import (
compute_split_seq_index,
compute_split_token_index,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
......@@ -73,35 +76,46 @@ class TestTwoBatchOverlap(unittest.TestCase):
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
# TODO change tests when having 6328
def test_compute_split_seq_index(self):
def test_compute_split_seq_and_token_index(self):
for num_tokens, expect in [
(0, 0),
(100, 50),
(99, 49),
]:
actual = compute_split_seq_index(
forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
forward_mode=ForwardMode.DECODE,
num_tokens=num_tokens,
extend_lens=None,
token_num_per_seq=1,
)
self.assertEqual(actual, expect)
for extend_lens, expect in [
([], 0),
([42], 0),
([42, 999], 1),
([999, 42], 1),
([4096, 4096, 4096, 4096], 2),
([4095, 4096, 4096, 4096, 1], 2),
([1, 4095, 4096, 4096, 4096], 3),
([4097, 4096, 4096, 4095, 1], 2),
([1, 1, 1, 1, 99999], 4),
([99999, 1, 1, 1, 1], 1),
([], (0, 0)),
([42], (0, 21)),
([42, 999], (1, 520)),
([999, 42], (0, 520)),
([498, 502], (1, 498)),
([4096, 4096, 4096, 4096], (2, 8192)),
([4095, 4096, 4096, 4096, 1], (2, 8191)),
([1, 4095, 4096, 4096, 4096], (3, 8192)),
([4097, 4096, 4096, 4095, 1], (2, 8193)),
([1, 1, 1, 1, 99999], (4, 50001)),
([99999, 1, 1, 1, 1], (0, 50001)),
]:
actual = compute_split_seq_index(
actual_seq_idx = compute_split_seq_index(
forward_mode=ForwardMode.EXTEND,
num_tokens=None,
extend_lens=extend_lens,
token_num_per_seq=None,
)
actual_token_idx = compute_split_token_index(
split_seq_index=actual_seq_idx,
forward_mode=ForwardMode.EXTEND,
extend_seq_lens=extend_lens,
token_num_per_seq=None,
)
actual = (actual_seq_idx, actual_token_idx)
print(f"{extend_lens=} {expect=} {actual=}")
self.assertEqual(actual, expect)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment