Unverified commit e1d85e5c, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention] Support distinguishing between short extends and decodes (#37303)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 79eb9369
......@@ -70,3 +70,15 @@ steps:
device: mi325_4
depends_on:
- image-build-amd
- label: V1 e2e (4xH100)
timeout_in_minutes: 60
device: h100
num_devices: 4
optional: true
source_file_dependencies:
- vllm/v1/attention/backends/utils.py
- vllm/v1/worker/gpu_model_runner.py
- tests/v1/e2e/test_hybrid_chunked_prefill.py
commands:
- pytest -v -s v1/e2e/test_hybrid_chunked_prefill.py
......@@ -10,9 +10,10 @@ from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_
class MockInputBatch:
def __init__(self, req_ids, num_computed_tokens_cpu):
def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
self.req_ids = req_ids
self.num_computed_tokens_cpu = num_computed_tokens_cpu
self.num_prompt_tokens = num_prompt_tokens
def swap_states(self, i, j):
self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
......@@ -20,6 +21,10 @@ class MockInputBatch:
self.num_computed_tokens_cpu[j],
self.num_computed_tokens_cpu[i],
)
self.num_prompt_tokens[i], self.num_prompt_tokens[j] = (
self.num_prompt_tokens[j],
self.num_prompt_tokens[i],
)
class MockSchedulerOutput:
......@@ -29,96 +34,139 @@ class MockSchedulerOutput:
@dataclass
class ReorderTestCase:
requests: list[tuple[int, int]] # (num_scheduled_tokens, num_computed_tokens)
# (num_scheduled_tokens, num_computed_tokens, num_prompt_tokens)
requests: list[tuple[int, int, int]]
expected_order: list[int]
expected_modified: bool
decode_threshold: int = 1
# Test cases for batch reordering
# Format: (num_scheduled, num_computed, num_prompt)
REORDER_TEST_CASES = {
"all_decodes": ReorderTestCase(
requests=[(1, 10), (1, 20), (1, 30)],
requests=[(1, 10, 10), (1, 20, 20), (1, 30, 30)],
expected_order=[0, 1, 2],
expected_modified=False,
),
"all_prefills": ReorderTestCase(
requests=[(100, 100), (200, 200), (300, 300)],
"all_long_extends": ReorderTestCase(
requests=[(100, 100, 100), (200, 200, 200), (300, 300, 300)],
expected_order=[0, 1, 2],
expected_modified=False,
),
"mixed_interleaved": ReorderTestCase(
requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
"mixed_decodes_long_extends": ReorderTestCase(
requests=[(100, 100, 100), (1, 10, 10), (200, 200, 200), (1, 20, 20)],
expected_order=[3, 1, 2, 0],
expected_modified=True,
),
"already_ordered": ReorderTestCase(
requests=[(1, 10), (1, 20), (100, 100), (200, 0)],
requests=[(1, 10, 10), (1, 20, 20), (100, 100, 100), (200, 0, 200)],
expected_order=[0, 1, 2, 3],
expected_modified=False,
),
"single_request": ReorderTestCase(
requests=[(1, 10)],
requests=[(1, 10, 10)],
expected_order=[0],
expected_modified=False,
),
"higher_threshold": ReorderTestCase(
requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
requests=[(2, 10, 10), (3, 20, 20), (5, 30, 30), (6, 40, 40)],
expected_order=[0, 1, 2, 3],
expected_modified=False,
decode_threshold=4,
),
"decodes_at_end": ReorderTestCase(
requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
requests=[(100, 100, 100), (200, 200, 200), (1, 10, 10), (1, 20, 20)],
expected_order=[2, 3, 0, 1],
expected_modified=True,
),
"decode_extend_prefill": ReorderTestCase(
requests=[(100, 0), (10, 50), (1, 10)],
"decode_long_extend_prefill": ReorderTestCase(
requests=[(100, 0, 100), (10, 50, 50), (1, 10, 10)],
expected_order=[2, 1, 0],
expected_modified=True,
),
"extend_prefill_only": ReorderTestCase(
requests=[(100, 0), (10, 50), (200, 0), (20, 75)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place
"long_extend_prefill_only": ReorderTestCase(
requests=[(100, 0, 100), (10, 50, 50), (200, 0, 200), (20, 75, 75)],
expected_order=[3, 1, 2, 0],
expected_modified=True,
),
"complicated_mixed_interleaved": ReorderTestCase(
"complicated_mixed": ReorderTestCase(
requests=[
(1, 20),
(1, 50),
(374, 0),
(300, 20),
(1, 20),
(256, 0),
(1, 5),
(27, 0),
(1, 4),
(1, 20, 20), # decode
(1, 50, 50), # decode
(374, 0, 374), # prefill
(300, 20, 20), # long_extend
(1, 20, 20), # decode
(256, 0, 256), # prefill
(1, 5, 5), # decode
(27, 0, 27), # prefill
(1, 4, 4), # decode
],
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
expected_modified=True,
),
"new_request_single_token_prefill": ReorderTestCase(
requests=[
(100, 0),
(1, 0), # New request with only 1 token (STILL prefill)
(50, 100),
(1, 10),
(100, 0, 100), # prefill
(1, 0, 1), # prefill (single token, still prefill)
(50, 100, 100), # long_extend
(1, 10, 10), # decode
],
# Only index 3 is a true decode (has num_computed_tokens > 0)
expected_order=[3, 2, 0, 1],
expected_modified=True,
),
"multiple_new_requests_single_token_prefill": ReorderTestCase(
requests=[
(1, 0), # New prefill (1 token, no computed)
(1, 0), # New prefill (1 token, no computed)
(1, 50),
(200, 0),
(1, 0, 1), # prefill
(1, 0, 1), # prefill
(1, 50, 50), # decode
(200, 0, 200), # prefill
],
expected_order=[2, 1, 0, 3],
expected_modified=True,
),
"four_way_already_ordered": ReorderTestCase(
requests=[
(1, 100, 100), # decode
(1, 50, 100), # short_extend
(10, 50, 100), # long_extend
(100, 0, 100), # prefill
],
expected_order=[0, 1, 2, 3],
expected_modified=False,
),
"four_way_needs_reorder": ReorderTestCase(
requests=[
(100, 0, 100), # prefill
(1, 50, 100), # short_extend
(1, 100, 100), # decode
(10, 50, 100), # long_extend
],
expected_order=[2, 1, 3, 0],
expected_modified=True,
),
"four_way_multiple_short_extends": ReorderTestCase(
requests=[
(2, 100, 100), # decode
(2, 50, 200), # short_extend
(2, 75, 150), # short_extend
(2, 200, 200), # decode
],
expected_order=[0, 3, 2, 1],
expected_modified=True,
decode_threshold=2,
),
"four_way_spec_decode_threshold": ReorderTestCase(
requests=[
(5, 100, 100), # decode
(5, 50, 100), # short_extend
(5, 0, 100), # prefill
(10, 50, 100), # long_extend
],
expected_order=[0, 1, 3, 2],
expected_modified=True,
decode_threshold=5,
),
}
......@@ -129,8 +177,9 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase)
req_ids = [f"r{i}" for i in range(len(test_case.requests))]
num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}
num_prompt_tokens = np.array([r[2] for r in test_case.requests], dtype=np.int32)
input_batch = MockInputBatch(req_ids, num_computed_tokens)
input_batch = MockInputBatch(req_ids, num_computed_tokens, num_prompt_tokens)
scheduler_output = MockSchedulerOutput(num_scheduled_tokens)
modified = reorder_batch_to_split_decodes_and_prefills(
......
......@@ -43,7 +43,7 @@ MESSAGES = [
pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]),
pytest.param(
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2),
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=4),
),
],
)
......@@ -68,7 +68,7 @@ def test_mtp_speculative_mixed_batch_short_prefill(
max_num_batched_tokens=chunk_size,
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=2,
tensor_parallel_size=4,
trust_remote_code=True,
enable_chunked_prefill=True,
enable_prefix_caching=enable_prefix_caching,
......
......@@ -362,6 +362,11 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
is_prefilling: torch.Tensor | None = None
"""(batch_size,) bool tensor: True if request is still in prefill phase
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
......@@ -443,6 +448,7 @@ class CommonAttentionMetadata:
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
is_prefilling=maybe_slice_reqs(self.is_prefilling),
)
......
......@@ -358,7 +358,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold
common_attn_metadata,
decode_threshold=decode_threshold,
treat_short_extends_as_decodes=False,
)
)
......
......@@ -489,11 +489,15 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False,
treat_short_extends_as_decodes: bool = True,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
The batch is expected to be ordered as:
decode → short_extend → long_extend → prefill
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
......@@ -501,6 +505,9 @@ def split_decodes_and_prefills(
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.
treat_short_extends_as_decodes: If True (default), short extends
(query_len <= threshold but still prefilling) are counted as
decodes. If False, they are counted as prefills.
Returns:
num_decodes: The number of decode requests.
......@@ -513,8 +520,10 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold and (
not require_uniform or decode_threshold <= 1
if (
max_query_len <= decode_threshold
and (not require_uniform or decode_threshold <= 1)
and treat_short_extends_as_decodes
):
return num_reqs, 0, num_tokens, 0
......@@ -533,11 +542,14 @@ def split_decodes_and_prefills(
else:
is_prefill = query_lens > decode_threshold
if not treat_short_extends_as_decodes:
assert common_attn_metadata.is_prefilling is not None
is_prefill |= common_attn_metadata.is_prefilling
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()
......@@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills(
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
The batch is reordered into 4 regions:
decode: (num_scheduled <= threshold AND is not prefilling)
short_extend: (num_scheduled <= threshold AND is chunked prefilling)
long_extend: (num_scheduled > threshold AND is chunked prefilling)
prefill: (num_computed == 0) # First chunks
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch into decode → extend → prefill order
# where:
# decode: request with num_scheduled_tokens <= decode_threshold
# extend: non-decode request with existing context
# prefill: non-decode request with no existing context
# NOTE for now we loosely use "decode" to mean requests where attention is
# likely memory-bound and "prefill" to mean requests where attention is
# likely compute-bound,
num_reqs = len(input_batch.req_ids)
num_scheduled_tokens = [
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
]
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
is_prefill = num_computed_tokens_np == 0
is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)
# Desired order: decode → extend → prefill
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
req_regions[is_extend] = 1
req_regions[is_prefill] = 2
num_prompt_tokens_np = input_batch.num_prompt_tokens[:num_reqs]
has_context = num_computed_tokens_np > 0
is_below_threshold = num_scheduled_tokens_np <= decode_threshold
done_prefilling = num_computed_tokens_np >= num_prompt_tokens_np
# Mutually exclusive categories (exactly one True per request):
# 1. No context yet -> prefill
# 2. Has context, above threshold -> long_extend
# 3. Has context, below threshold, still prefilling -> short_extend
# 4. Has context, below threshold, done prefilling -> decode
is_pure_prefill = ~has_context
is_long_extend = has_context & ~is_below_threshold
is_short_extend = has_context & is_below_threshold & ~done_prefilling
is_decode = has_context & is_below_threshold & done_prefilling
# Desired order: decode → short_extend → long_extend → prefill
req_regions = np.zeros(num_reqs, dtype=np.int32) # 0 = decode by default
req_regions[is_short_extend] = 1
req_regions[is_long_extend] = 2
req_regions[is_pure_prefill] = 3
num_decodes = int(is_decode.sum())
num_extends = int(is_extend.sum())
num_short_extends = int(is_short_extend.sum())
num_long_extends = int(is_long_extend.sum())
num_prefills = int(is_pure_prefill.sum())
target_regions = np.zeros(num_reqs, dtype=np.int32)
target_regions[num_decodes : num_decodes + num_extends] = 1
target_regions[num_decodes + num_extends :] = 2
target_regions = np.repeat(
[0, 1, 2, 3],
[num_decodes, num_short_extends, num_long_extends, num_prefills],
).astype(np.int32)
needs_swap = req_regions != target_regions
......
......@@ -134,7 +134,13 @@ class InputBatch:
pin_memory=pin_memory,
)
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
......
......@@ -740,19 +740,6 @@ class GPUModelRunner(
self.uniform_decode_query_len = 1 + self.num_spec_tokens
# When spec decode is active, the mamba backend classifies requests
# with query_len <= reorder_batch_threshold as "decodes". Prefill
# chunks that fall under this threshold get processed via the decode
# path, which stores intermediate states at sequential slots. We must
# set num_accepted_tokens to the chunk's query_len for those requests
# so the next iteration reads from the correct final-state slot.
# Prefills that went through the actual prefill path should keep the
# default value of 1 (the prefill path stores state at slot 0 only).
self.needs_prefill_as_decode_slots: bool = False
self.prefill_as_decode_num_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
......@@ -1369,16 +1356,6 @@ class GPUModelRunner(
.int()
.argmax(-1)
)
spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens)
if self.needs_prefill_as_decode_slots and spec_decode_active:
mamba_utils.update_accepted_tokens_for_prefill_as_decode(
self.input_batch,
self.prefill_as_decode_num_tokens,
self.num_accepted_tokens.gpu,
scheduler_output,
self.reorder_batch_threshold,
num_reqs,
)
if self.cache_config.mamba_cache_mode == "align":
for i, num_tokens in enumerate(
......@@ -1982,14 +1959,23 @@ class GPUModelRunner(
attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
# distinguish actual decodes from short extends.
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[
:num_reqs_padded
]
is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
max_query_len=max_query_len,
......@@ -1997,6 +1983,7 @@ class GPUModelRunner(
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
causal=True,
is_prefilling=is_prefilling,
)
if self.dcp_world_size > 1:
......@@ -2048,8 +2035,6 @@ class GPUModelRunner(
else 0
)
if isinstance(builder, Mamba2AttentionMetadataBuilder):
self.needs_prefill_as_decode_slots = True
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(
builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
......
......@@ -266,45 +266,3 @@ def postprocess_mamba(
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(copy_bufs)
def update_accepted_tokens_for_prefill_as_decode(
    input_batch: GPUInputBatch,
    prefill_as_decode_num_tokens: CpuGpuBuffer,
    num_accepted_tokens_gpu: torch.Tensor,
    scheduler_output: SchedulerOutput,
    decode_qlen_threshold: int | None,
    num_reqs: int,
):
    """
    Override num_accepted_tokens for prefill chunks that ran via the decode path.

    This ensures subsequent iterations read from the correct sequential state
    slot instead of the default prefill slot 0. Not used by GDN attention,
    which manually separates short prefills and short decodes when building
    the attention metadata.

    For each of the first ``num_reqs`` requests a value is staged in
    ``prefill_as_decode_num_tokens.np``:

    * still prefilling (``num_computed < num_prompt``) and the scheduled query
      length is within ``decode_qlen_threshold``: the query length;
    * still prefilling, but the threshold is ``None`` or exceeded: ``1``;
    * not prefilling: ``-1``, meaning "leave the existing value alone".

    Args:
        input_batch: Supplies ``num_computed_tokens_cpu``,
            ``num_prompt_tokens`` and ``req_ids``, indexed per request.
        prefill_as_decode_num_tokens: Staging buffer; its ``np`` view is
            written here, and its ``gpu`` view is read after ``copy_to_gpu``.
        num_accepted_tokens_gpu: Tensor whose first ``num_reqs`` entries are
            overwritten in place wherever the staged value is not ``-1``.
        scheduler_output: Supplies ``num_scheduled_tokens`` keyed by request id.
        decode_qlen_threshold: Largest query length classified as a decode,
            or ``None`` when no request is classified that way.
        num_reqs: Number of leading requests to process.
    """
    saw_prefill = False
    for idx in range(num_reqs):
        computed = input_batch.num_computed_tokens_cpu[idx]
        prompt_len = input_batch.num_prompt_tokens[idx]
        still_prefilling = computed < prompt_len
        # Looked up for every request, prefilling or not.
        request_id = input_batch.req_ids[idx]
        chunk_len = scheduler_output.num_scheduled_tokens[request_id]

        if not still_prefilling:
            # Sentinel: keep whatever num_accepted_tokens already holds.
            staged_value = -1
        else:
            saw_prefill = True
            if decode_qlen_threshold is None or chunk_len > decode_qlen_threshold:
                staged_value = 1
            else:
                staged_value = chunk_len
        prefill_as_decode_num_tokens.np[idx] = staged_value

    # Nothing to override, so the GPU transfer can be skipped entirely.
    if not saw_prefill:
        return

    prefill_as_decode_num_tokens.copy_to_gpu(num_reqs)
    overrides = prefill_as_decode_num_tokens.gpu[:num_reqs]
    num_accepted_tokens_gpu[:num_reqs] = torch.where(
        overrides != -1,
        overrides,
        num_accepted_tokens_gpu[:num_reqs],
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment