Unverified Commit b5d70751 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Reordering extend logic fix (#27739)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent b8c48c5d
...@@ -53,7 +53,7 @@ REORDER_TEST_CASES = { ...@@ -53,7 +53,7 @@ REORDER_TEST_CASES = {
expected_modified=True, expected_modified=True,
), ),
"already_ordered": ReorderTestCase( "already_ordered": ReorderTestCase(
requests=[(1, 10), (1, 20), (100, 100), (200, 200)], requests=[(1, 10), (1, 20), (100, 100), (200, 0)],
expected_order=[0, 1, 2, 3], expected_order=[0, 1, 2, 3],
expected_modified=False, expected_modified=False,
), ),
...@@ -74,15 +74,30 @@ REORDER_TEST_CASES = { ...@@ -74,15 +74,30 @@ REORDER_TEST_CASES = {
expected_modified=True, expected_modified=True,
), ),
"decode_extend_prefill": ReorderTestCase( "decode_extend_prefill": ReorderTestCase(
requests=[(100, 100), (10, 50), (1, 10)], requests=[(100, 0), (10, 50), (1, 10)],
expected_order=[2, 1, 0], expected_order=[2, 1, 0],
expected_modified=True, expected_modified=True,
), ),
"extend_prefill_only": ReorderTestCase( "extend_prefill_only": ReorderTestCase(
requests=[(100, 100), (10, 50), (200, 200), (20, 75)], requests=[(100, 0), (10, 50), (200, 0), (20, 75)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
expected_modified=True, expected_modified=True,
), ),
"complicated_mixed_interleaved": ReorderTestCase(
requests=[
(1, 20),
(1, 50),
(374, 0),
(300, 20),
(1, 20),
(256, 0),
(1, 5),
(27, 0),
(1, 4),
],
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
expected_modified=True,
),
} }
......
...@@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills(
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
is_decode = num_scheduled_tokens_np <= decode_threshold is_decode = num_scheduled_tokens_np <= decode_threshold
is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np) is_extend = (~is_decode) & (num_computed_tokens_np > 0)
is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np) is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
# Desired order: decode → extend → prefill # Desired order: decode → extend → prefill
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
...@@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills(
return False return False
# Extract indices that need swapping and sort by target region # Extract indices that need swapping and sort by target region
swap_indices = np.where(needs_swap)[0] orig_indices = np.where(needs_swap)[0]
sorted_order = np.argsort(req_regions[needs_swap], kind="stable") sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
dest_indices = swap_indices[sorted_order] src_indices = orig_indices[sorted_order]
src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)} src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
for src in src_dest_map: for src in src_dest_map:
dst = src_dest_map[src] dst = src_dest_map[src]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment