Unverified commit a191a0e4, authored by fzyzcjy and committed by GitHub
Browse files

Improve performance of two batch overlap in some imbalanced cases (#6593)

parent 8c7279c2
...@@ -40,13 +40,21 @@ def compute_split_seq_index( ...@@ -40,13 +40,21 @@ def compute_split_seq_index(
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
    """Return the split index that best balances the two sides of ``arr``.

    The result ``k`` partitions the sequence into ``arr[:k]`` and ``arr[k:]``
    such that ``abs(sum(arr[:k]) - sum(arr[k:]))`` is as small as possible
    among the candidates ``1 .. len(arr) - 1``.  When several candidates tie,
    the largest index wins.

    Args:
        arr: Per-item sizes to be divided into a left and a right part.

    Returns:
        The chosen split index, or ``0`` when ``arr`` has fewer than two
        items (there is no candidate that leaves both parts non-empty).

    NOTE(review): the early exit below is only correct when the entries are
    non-negative, so that the running left sum never decreases -- confirm
    against the callers.
    """
    overall_sum = sum(arr)
    left_sum = 0
    min_diff = float("inf")
    best_index = 0

    # The last element is excluded so that the right part is never empty.
    for split_index, value in enumerate(arr[:-1], start=1):
        left_sum += value
        right_sum = overall_sum - left_sum
        diff = abs(left_sum - right_sum)
        if diff > min_diff:
            # The imbalance started growing again; later splits are no better.
            break
        # `<=` semantics: an equally good later split replaces the earlier one.
        min_diff = diff
        best_index = split_index

    return best_index
def compute_split_token_index( def compute_split_token_index(
......
...@@ -4,6 +4,8 @@ from types import SimpleNamespace ...@@ -4,6 +4,8 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import compute_split_seq_index
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -68,5 +70,39 @@ class TestTwoBatchOverlap(unittest.TestCase): ...@@ -68,5 +70,39 @@ class TestTwoBatchOverlap(unittest.TestCase):
self.assertGreater(metrics["score"], 0.5) self.assertGreater(metrics["score"], 0.5)
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
    """Table-driven checks of ``compute_split_seq_index``."""

    # TODO change tests when having 6328
    def test_compute_split_seq_index(self):
        """Check the split index for DECODE and EXTEND forward modes.

        Every case runs inside ``subTest`` so that one failing case does not
        hide the results of the remaining ones, and the failing parameters
        are reported by the test runner itself.
        """
        # DECODE: in the listed cases the expectation equals num_tokens // 2.
        decode_cases = [
            (0, 0),
            (100, 50),
            (99, 49),
        ]
        for num_tokens, expect in decode_cases:
            with self.subTest(forward_mode="DECODE", num_tokens=num_tokens):
                actual = compute_split_seq_index(
                    forward_mode=ForwardMode.DECODE,
                    num_tokens=num_tokens,
                    extend_lens=None,
                )
                self.assertEqual(actual, expect)

        # EXTEND: the expected index is presumably the one that makes the sums
        # of extend_lens on both sides of the split as close as possible.
        extend_cases = [
            ([], 0),
            ([42], 0),
            ([42, 999], 1),
            ([999, 42], 1),
            ([4096, 4096, 4096, 4096], 2),
            ([4095, 4096, 4096, 4096, 1], 2),
            ([1, 4095, 4096, 4096, 4096], 3),
            ([4097, 4096, 4096, 4095, 1], 2),
            ([1, 1, 1, 1, 99999], 4),
            ([99999, 1, 1, 1, 1], 1),
        ]
        for extend_lens, expect in extend_cases:
            with self.subTest(forward_mode="EXTEND", extend_lens=extend_lens):
                actual = compute_split_seq_index(
                    forward_mode=ForwardMode.EXTEND,
                    num_tokens=None,
                    extend_lens=extend_lens,
                )
                self.assertEqual(actual, expect)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment