Unverified commit ca47e24f authored by HouseWest, committed by GitHub

[Feature] improve TBO: two chunk overlap (#8144)

parent d26ca84f
@@ -262,6 +262,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
 | `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
 | `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
+| `--tbo-token-distribution-threshold` | The threshold of token distribution between the two micro-batches in two-batch overlap; determines whether to use two-batch overlap or two-chunk overlap. Set to 0 to disable two-chunk overlap. | 0.48 |
 | `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
 | `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
 | `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
......
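For intuition: TBO normally assigns whole sequences to the two micro-batches. The new threshold decides when that is hopeless — if even the most token-balanced whole-sequence split would leave one micro-batch with less than 48% of the tokens (by default), the batch is instead split mid-sequence into two near-equal chunks ("two-chunk overlap"). Below is a minimal sketch of that decision rule, mirroring `_is_two_chunk_split_enabled` added later in this commit; the function name and standalone form are illustrative, not the real API:

```python
def needs_two_chunk_overlap(extend_lens, threshold=0.48):
    """Sketch: True when no whole-sequence split is balanced enough."""
    total = sum(extend_lens)
    if total == 0:
        return False
    # Most balanced whole-sequence split point: minimize |left - right|.
    best_left = min(
        (sum(extend_lens[:i]) for i in range(len(extend_lens) + 1)),
        key=lambda left: abs(total - 2 * left),
    )
    # Lopsided beyond the threshold in either direction -> two-chunk overlap.
    return best_left < total * threshold or best_left > total * (1 - threshold)


print(needs_two_chunk_overlap([500, 500]))  # False: a 50/50 split exists
print(needs_two_chunk_overlap([42, 999]))   # True: best split is ~4% / 96%
```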
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_radix_cache",
     "enable_dp_attention",
     "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
     "moe_a2a_backend",
     "deepep_mode",
......
@@ -420,16 +420,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-            if support_triton(model_runner.server_args.attention_backend):
-                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
-                    ret.extend_seq_lens,
-                    ret.extend_num_tokens,
-                )
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -882,6 +878,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"


+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
......
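For reference, both backends compute the same two quantities; the sketch below spells out the intended semantics in plain PyTorch. It is an illustration of the contract, not the actual `compute_position_triton`/`compute_position_torch` bodies:

```python
import torch


def compute_position_reference(extend_prefix_lens, extend_seq_lens):
    """Sketch of the contract shared by both backends (assumed semantics)."""
    # positions: each new token's absolute index within its sequence,
    # i.e. prefix_len .. prefix_len + extend_len - 1, concatenated.
    positions = torch.cat([
        torch.arange(p, p + n)
        for p, n in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
    ])
    # extend_start_loc: offset of each sequence's first new token on the
    # flattened token axis (exclusive prefix sum of extend_seq_lens).
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc


pos, loc = compute_position_reference(torch.tensor([0, 10]), torch.tensor([3, 2]))
print(pos.tolist(), loc.tolist())  # [0, 1, 2, 10, 11] [0, 3]
```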
@@ -229,6 +229,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
     enable_two_batch_overlap: bool = False
+    tbo_token_distribution_threshold: float = 0.48
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     torchao_config: str = ""
@@ -1689,6 +1690,12 @@ class ServerArgs:
             action="store_true",
             help="Enabling two micro batches to overlap.",
         )
+        parser.add_argument(
+            "--tbo-token-distribution-threshold",
+            type=float,
+            default=ServerArgs.tbo_token_distribution_threshold,
+            help="The threshold of token distribution between the two micro-batches in two-batch overlap; determines whether to use two-batch overlap or two-chunk overlap. Set to 0 to disable two-chunk overlap.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
......
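Usage note: the threshold only has an effect together with `--enable-two-batch-overlap`. A hypothetical programmatic equivalent of the CLI flags (the model path is a placeholder, and constructing `ServerArgs` directly is shown purely for illustration):

```python
from sglang.srt.server_args import ServerArgs

# Same effect as passing, on the command line:
#   --enable-two-batch-overlap --tbo-token-distribution-threshold 0.4
args = ServerArgs(
    model_path="path/to/model",  # placeholder
    enable_two_batch_overlap=True,
    tbo_token_distribution_threshold=0.4,  # must be <= 0.5, asserted downstream
)
```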
 from __future__ import annotations

+import copy
 import dataclasses
 import logging
 from dataclasses import replace
@@ -17,7 +18,11 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    compute_position,
+)
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -58,7 +63,7 @@ def compute_split_seq_index(
 ) -> Optional[int]:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_lens is not None
-        return _split_array_by_half_sum(extend_lens)
+        return _split_extend_seqs(extend_lens)
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
         return (num_tokens // token_num_per_seq) // 2
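The decode/target-verify branch is uniform by construction: every sequence contributes exactly `token_num_per_seq` tokens, so splitting the sequence count in half also splits the tokens in half. A quick worked instance (numbers illustrative):

```python
# 25 sequences, each verifying 4 draft tokens -> 100 tokens total.
num_tokens, token_num_per_seq = 100, 4
split_seq_index = (num_tokens // token_num_per_seq) // 2
print(split_seq_index)  # 12: seqs [0, 12) go to child A, [12, 25) to child B
```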
@@ -69,7 +74,43 @@ def compute_split_seq_index(
     raise NotImplementedError()


-def _split_array_by_half_sum(arr: Sequence[int]) -> int:
+def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
+    if extend_lens is None:
+        return False
+    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
+    left_sum = sum(extend_lens[:vanilla_split_seq_index])
+    overall_sum = sum(extend_lens)
+    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    assert threshold <= 0.5, f"{threshold=}"
+    return left_sum < overall_sum * threshold or left_sum > overall_sum * (
+        1 - threshold
+    )
+
+
+def _split_extend_seqs(arr: Sequence[int]) -> int:
+    if _is_two_chunk_split_enabled(arr):
+        return _split_array_by_cum_less_than_half(arr)
+    return _split_array_by_balanced_sum(arr)
+
+
+def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
+    left_sum = 0
+    overall_sum = sum(arr)
+    half_sum = overall_sum // 2
+    chosen_index = 0
+    for i in range(len(arr)):
+        left_sum += arr[i]
+        if left_sum > half_sum:
+            chosen_index = i
+            break
+    return chosen_index
+
+
+def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
     overall_sum = sum(arr)
     left_sum = 0
     min_diff = float("inf")
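A worked trace of the two strategies on `extend_lens = [42, 999]`, one of the cases in the updated unit test at the bottom of this commit (a standalone sketch; the constant `0.48` stands in for the `tbo_token_distribution_threshold` server arg):

```python
extend_lens = [42, 999]
total, half = sum(extend_lens), sum(extend_lens) // 2  # 1041, 520

# _split_array_by_balanced_sum would pick index 1 (left = 42 tokens), but
# 42 / 1041 ~ 4% < 48%, so _is_two_chunk_split_enabled returns True and
# _split_array_by_cum_less_than_half takes over: it returns the index of
# the first sequence whose cumulative length crosses the half mark --
# the sequence that will straddle the chunk boundary.
cum, seq_index = 0, 0
for i, n in enumerate(extend_lens):
    cum += n
    if cum > half:
        seq_index = i
        break

# compute_split_token_index then returns exactly total // 2, so sequence 1
# is cut into 520 - 42 = 478 tokens (chunk A) and 521 tokens (chunk B).
print(seq_index, half)  # 1 520 -- matching ([42, 999], (1, 520)) in the test
```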
@@ -88,6 +129,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     return best_index


+def _update_device_and_sum_field_from_cpu_field(
+    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
+):
+    cpu_value = getattr(batch, cpu_field, None)
+    old_device_value = getattr(batch, device_field, None)
+    if (
+        cpu_value is None
+        or old_device_value is None
+        or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
+    ):
+        return
+    new_device_value = (
+        cpu_value
+        if isinstance(cpu_value, torch.Tensor)
+        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
+    ).to(device=global_server_args_dict["device"], non_blocking=True)
+    setattr(batch, device_field, new_device_value)
+    if sum_field is not None:
+        sum_value = (
+            cpu_value.sum().item()
+            if isinstance(cpu_value, torch.Tensor)
+            else sum(cpu_value)
+        )
+        setattr(batch, sum_field, sum_value)
+
+
 def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
     if seq_index == 0:
         return 0
@@ -181,6 +250,8 @@ def compute_split_token_index(
 ) -> int:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_seq_lens is not None
+        if _is_two_chunk_split_enabled(extend_seq_lens):
+            return sum(extend_seq_lens) // 2
         return sum(extend_seq_lens[:split_seq_index])
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
@@ -388,9 +459,15 @@ class TboForwardBatchPreparer:
         tbo_split_token_index = cls._compute_split_token_index(batch)

+        is_enable_two_chunk = (
+            batch.forward_mode == ForwardMode.EXTEND
+            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
+        )
+
         if _tbo_debug:
             logger.info(
                 f"TboForwardBatchPreparer.prepare "
+                f"is_enable_two_chunk={is_enable_two_chunk} "
                 f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                 f"tbo_split_token_index={tbo_split_token_index} "
                 f"extend_seq_lens={batch.extend_seq_lens_cpu} "
@@ -410,7 +487,11 @@ class TboForwardBatchPreparer:
             start_token_index=0,
             end_token_index=tbo_split_token_index,
             start_seq_index=0,
-            end_seq_index=batch.tbo_split_seq_index,
+            end_seq_index=(
+                batch.tbo_split_seq_index + 1
+                if is_enable_two_chunk
+                else batch.tbo_split_seq_index
+            ),
             output_attn_backend=attn_backend_child_a,
             out_num_token_non_padded=out_num_token_non_padded_a,
         )
@@ -424,9 +505,79 @@ class TboForwardBatchPreparer:
             out_num_token_non_padded=out_num_token_non_padded_b,
         )

+        if is_enable_two_chunk:
+            cls.derive_fields_related_to_seq_len_for_two_chunk(
+                batch,
+                child_a=child_a,
+                child_b=child_b,
+                tbo_split_seq_index=batch.tbo_split_seq_index,
+            )
+
         assert batch.tbo_children is None
         batch.tbo_children = [child_a, child_b]

+    @classmethod
+    def derive_fields_related_to_seq_len_for_two_chunk(
+        cls,
+        batch: ForwardBatch,
+        *,
+        child_a: ForwardBatch,
+        child_b: ForwardBatch,
+        tbo_split_seq_index: int,
+    ):
+        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
+        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
+        half_seq_lens_sum = overall_seq_lens_sum // 2
+        left_last_seq_token_num = half_seq_lens_sum - sum(
+            extend_seq_lens_cpu[:tbo_split_seq_index]
+        )
+        right_first_seq_token_num = (
+            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
+        )
+
+        # making deepcopy to be extra safe
+        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
+        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
+        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
+        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
+
+        for child in [child_a, child_b]:
+            _update_device_and_sum_field_from_cpu_field(
+                batch=child,
+                cpu_field="extend_seq_lens_cpu",
+                device_field="extend_seq_lens",
+                sum_field="extend_num_tokens",
+            )
+        assert (
+            child_a.extend_num_tokens == half_seq_lens_sum
+        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
+
+        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
+        child_a.seq_lens_cpu[-1] = (
+            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
+        )
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_a,
+            cpu_field="seq_lens_cpu",
+            device_field="seq_lens",
+            sum_field="seq_lens_sum",
+        )
+
+        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
+        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_b,
+            cpu_field="extend_prefix_lens_cpu",
+            device_field="extend_prefix_lens",
+            sum_field=None,
+        )
+        _, child_b.extend_start_loc = compute_position(
+            global_server_args_dict["attention_backend"],
+            child_b.extend_prefix_lens,
+            child_b.extend_seq_lens,
+            child_b.extend_num_tokens,
+        )
+
     @classmethod
     def filter_batch(
         cls,
......
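Continuing the `[42, 999]` example, here is a hypothetical trace of the field surgery above, assuming zero cached prefix. Chunk A keeps the straddling sequence in truncated form; chunk B then counts the tokens chunk A computed as prefix (their KV cache is presumably populated by the time chunk B's attention runs), which is why its `extend_prefix_lens` grows and its `extend_start_loc` must be recomputed via `compute_position`:

```python
extend_seq_lens = [42, 999]   # parent batch, as in the earlier example
prefix_lens = [0, 0]          # assumption: nothing cached yet
split_idx = 1                 # tbo_split_seq_index
half = sum(extend_seq_lens) // 2                        # 520

left_last = half - sum(extend_seq_lens[:split_idx])     # 478
right_first = extend_seq_lens[split_idx] - left_last    # 521

# child_a covers seqs [0, split_idx + 1) with the last one truncated:
child_a_extend = [extend_seq_lens[0], left_last]        # [42, 478] -> sums to 520
# child_b covers seqs [split_idx, ...); its first sequence's 478
# already-computed tokens become prefix rather than extend tokens:
child_b_extend = [right_first]                          # [521]
child_b_prefix = [prefix_lens[split_idx] + left_last]   # [478]
print(child_a_extend, child_b_extend, child_b_prefix)
```

The assertion above (`child_a.extend_num_tokens == half_seq_lens_sum`) pins down the invariant: chunk A always computes exactly the first half of the batch's tokens.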
@@ -5,7 +5,10 @@ from types import SimpleNamespace
 import requests

 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.two_batch_overlap import compute_split_seq_index
+from sglang.srt.two_batch_overlap import (
+    compute_split_seq_index,
+    compute_split_token_index,
+)
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -73,35 +76,46 @@ class TestTwoBatchOverlap(unittest.TestCase):

 class TestTwoBatchOverlapUnitTest(unittest.TestCase):
-    # TODO change tests when having 6328
-    def test_compute_split_seq_index(self):
+    def test_compute_split_seq_and_token_index(self):
         for num_tokens, expect in [
             (0, 0),
             (100, 50),
             (99, 49),
         ]:
             actual = compute_split_seq_index(
-                forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
+                forward_mode=ForwardMode.DECODE,
+                num_tokens=num_tokens,
+                extend_lens=None,
+                token_num_per_seq=1,
             )
             self.assertEqual(actual, expect)

         for extend_lens, expect in [
-            ([], 0),
-            ([42], 0),
-            ([42, 999], 1),
-            ([999, 42], 1),
-            ([4096, 4096, 4096, 4096], 2),
-            ([4095, 4096, 4096, 4096, 1], 2),
-            ([1, 4095, 4096, 4096, 4096], 3),
-            ([4097, 4096, 4096, 4095, 1], 2),
-            ([1, 1, 1, 1, 99999], 4),
-            ([99999, 1, 1, 1, 1], 1),
+            ([], (0, 0)),
+            ([42], (0, 21)),
+            ([42, 999], (1, 520)),
+            ([999, 42], (0, 520)),
+            ([498, 502], (1, 498)),
+            ([4096, 4096, 4096, 4096], (2, 8192)),
+            ([4095, 4096, 4096, 4096, 1], (2, 8191)),
+            ([1, 4095, 4096, 4096, 4096], (3, 8192)),
+            ([4097, 4096, 4096, 4095, 1], (2, 8193)),
+            ([1, 1, 1, 1, 99999], (4, 50001)),
+            ([99999, 1, 1, 1, 1], (0, 50001)),
         ]:
-            actual = compute_split_seq_index(
+            actual_seq_idx = compute_split_seq_index(
                 forward_mode=ForwardMode.EXTEND,
                 num_tokens=None,
                 extend_lens=extend_lens,
+                token_num_per_seq=None,
+            )
+            actual_token_idx = compute_split_token_index(
+                split_seq_index=actual_seq_idx,
+                forward_mode=ForwardMode.EXTEND,
+                extend_seq_lens=extend_lens,
+                token_num_per_seq=None,
             )
+            actual = (actual_seq_idx, actual_token_idx)
             print(f"{extend_lens=} {expect=} {actual=}")
             self.assertEqual(actual, expect)
......
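A reading aid for the new expected tuples: cases whose best whole-sequence split already lands within the 48%/52% band keep the old behavior (the token index is a plain prefix sum), while lopsided cases now split at exactly `total // 2`. Hand-checking two of them with the decision rule sketched earlier (illustrative, not the test's own code):

```python
for lens, expect in [([498, 502], (1, 498)), ([999, 42], (0, 520))]:
    total = sum(lens)
    # Most balanced whole-sequence split, as in the earlier sketch.
    best_left = min(
        (sum(lens[:i]) for i in range(len(lens) + 1)),
        key=lambda left: abs(total - 2 * left),
    )
    lopsided = best_left < 0.48 * total or best_left > 0.52 * total
    # [498, 502]: 49.8% / 50.2% is balanced -> token index 498 (prefix sum)
    # [999, 42]:  96% / 4% is lopsided      -> token index 1041 // 2 == 520
    print(lens, lopsided, expect)
```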