Unverified Commit e1d85e5c authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Support distinguishing between short extends and decodes (#37303)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 79eb9369
...@@ -70,3 +70,15 @@ steps: ...@@ -70,3 +70,15 @@ steps:
device: mi325_4 device: mi325_4
depends_on: depends_on:
- image-build-amd - image-build-amd
- label: V1 e2e (4xH100)
timeout_in_minutes: 60
device: h100
num_devices: 4
optional: true
source_file_dependencies:
- vllm/v1/attention/backends/utils.py
- vllm/v1/worker/gpu_model_runner.py
- tests/v1/e2e/test_hybrid_chunked_prefill.py
commands:
- pytest -v -s v1/e2e/test_hybrid_chunked_prefill.py
...@@ -10,9 +10,10 @@ from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_ ...@@ -10,9 +10,10 @@ from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_
class MockInputBatch: class MockInputBatch:
def __init__(self, req_ids, num_computed_tokens_cpu): def __init__(self, req_ids, num_computed_tokens_cpu, num_prompt_tokens):
self.req_ids = req_ids self.req_ids = req_ids
self.num_computed_tokens_cpu = num_computed_tokens_cpu self.num_computed_tokens_cpu = num_computed_tokens_cpu
self.num_prompt_tokens = num_prompt_tokens
def swap_states(self, i, j): def swap_states(self, i, j):
self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i] self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
...@@ -20,6 +21,10 @@ class MockInputBatch: ...@@ -20,6 +21,10 @@ class MockInputBatch:
self.num_computed_tokens_cpu[j], self.num_computed_tokens_cpu[j],
self.num_computed_tokens_cpu[i], self.num_computed_tokens_cpu[i],
) )
self.num_prompt_tokens[i], self.num_prompt_tokens[j] = (
self.num_prompt_tokens[j],
self.num_prompt_tokens[i],
)
class MockSchedulerOutput: class MockSchedulerOutput:
...@@ -29,96 +34,139 @@ class MockSchedulerOutput: ...@@ -29,96 +34,139 @@ class MockSchedulerOutput:
@dataclass @dataclass
class ReorderTestCase: class ReorderTestCase:
requests: list[tuple[int, int]] # (num_scheduled_tokens, num_computed_tokens) # (num_scheduled_tokens, num_computed_tokens, num_prompt_tokens)
requests: list[tuple[int, int, int]]
expected_order: list[int] expected_order: list[int]
expected_modified: bool expected_modified: bool
decode_threshold: int = 1 decode_threshold: int = 1
# Test cases for batch reordering # Test cases for batch reordering
# Format: (num_scheduled, num_computed, num_prompt)
REORDER_TEST_CASES = { REORDER_TEST_CASES = {
"all_decodes": ReorderTestCase( "all_decodes": ReorderTestCase(
requests=[(1, 10), (1, 20), (1, 30)], requests=[(1, 10, 10), (1, 20, 20), (1, 30, 30)],
expected_order=[0, 1, 2], expected_order=[0, 1, 2],
expected_modified=False, expected_modified=False,
), ),
"all_prefills": ReorderTestCase( "all_long_extends": ReorderTestCase(
requests=[(100, 100), (200, 200), (300, 300)], requests=[(100, 100, 100), (200, 200, 200), (300, 300, 300)],
expected_order=[0, 1, 2], expected_order=[0, 1, 2],
expected_modified=False, expected_modified=False,
), ),
"mixed_interleaved": ReorderTestCase( "mixed_decodes_long_extends": ReorderTestCase(
requests=[(100, 100), (1, 10), (200, 200), (1, 20)], requests=[(100, 100, 100), (1, 10, 10), (200, 200, 200), (1, 20, 20)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place expected_order=[3, 1, 2, 0],
expected_modified=True, expected_modified=True,
), ),
"already_ordered": ReorderTestCase( "already_ordered": ReorderTestCase(
requests=[(1, 10), (1, 20), (100, 100), (200, 0)], requests=[(1, 10, 10), (1, 20, 20), (100, 100, 100), (200, 0, 200)],
expected_order=[0, 1, 2, 3], expected_order=[0, 1, 2, 3],
expected_modified=False, expected_modified=False,
), ),
"single_request": ReorderTestCase( "single_request": ReorderTestCase(
requests=[(1, 10)], requests=[(1, 10, 10)],
expected_order=[0], expected_order=[0],
expected_modified=False, expected_modified=False,
), ),
"higher_threshold": ReorderTestCase( "higher_threshold": ReorderTestCase(
requests=[(2, 10), (3, 20), (5, 30), (6, 40)], requests=[(2, 10, 10), (3, 20, 20), (5, 30, 30), (6, 40, 40)],
expected_order=[0, 1, 2, 3], expected_order=[0, 1, 2, 3],
expected_modified=False, expected_modified=False,
decode_threshold=4, decode_threshold=4,
), ),
"decodes_at_end": ReorderTestCase( "decodes_at_end": ReorderTestCase(
requests=[(100, 100), (200, 200), (1, 10), (1, 20)], requests=[(100, 100, 100), (200, 200, 200), (1, 10, 10), (1, 20, 20)],
expected_order=[2, 3, 0, 1], expected_order=[2, 3, 0, 1],
expected_modified=True, expected_modified=True,
), ),
"decode_extend_prefill": ReorderTestCase( "decode_long_extend_prefill": ReorderTestCase(
requests=[(100, 0), (10, 50), (1, 10)], requests=[(100, 0, 100), (10, 50, 50), (1, 10, 10)],
expected_order=[2, 1, 0], expected_order=[2, 1, 0],
expected_modified=True, expected_modified=True,
), ),
"extend_prefill_only": ReorderTestCase( "long_extend_prefill_only": ReorderTestCase(
requests=[(100, 0), (10, 50), (200, 0), (20, 75)], requests=[(100, 0, 100), (10, 50, 50), (200, 0, 200), (20, 75, 75)],
expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place expected_order=[3, 1, 2, 0],
expected_modified=True, expected_modified=True,
), ),
"complicated_mixed_interleaved": ReorderTestCase( "complicated_mixed": ReorderTestCase(
requests=[ requests=[
(1, 20), (1, 20, 20), # decode
(1, 50), (1, 50, 50), # decode
(374, 0), (374, 0, 374), # prefill
(300, 20), (300, 20, 20), # long_extend
(1, 20), (1, 20, 20), # decode
(256, 0), (256, 0, 256), # prefill
(1, 5), (1, 5, 5), # decode
(27, 0), (27, 0, 27), # prefill
(1, 4), (1, 4, 4), # decode
], ],
expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5], expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5],
expected_modified=True, expected_modified=True,
), ),
"new_request_single_token_prefill": ReorderTestCase( "new_request_single_token_prefill": ReorderTestCase(
requests=[ requests=[
(100, 0), (100, 0, 100), # prefill
(1, 0), # New request with only 1 token (STILL prefill) (1, 0, 1), # prefill (single token, still prefill)
(50, 100), (50, 100, 100), # long_extend
(1, 10), (1, 10, 10), # decode
], ],
# Only index 3 is a true decode (has num_computed_tokens > 0)
expected_order=[3, 2, 0, 1], expected_order=[3, 2, 0, 1],
expected_modified=True, expected_modified=True,
), ),
"multiple_new_requests_single_token_prefill": ReorderTestCase( "multiple_new_requests_single_token_prefill": ReorderTestCase(
requests=[ requests=[
(1, 0), # New prefill (1 token, no computed) (1, 0, 1), # prefill
(1, 0), # New prefill (1 token, no computed) (1, 0, 1), # prefill
(1, 50), (1, 50, 50), # decode
(200, 0), (200, 0, 200), # prefill
], ],
expected_order=[2, 1, 0, 3], expected_order=[2, 1, 0, 3],
expected_modified=True, expected_modified=True,
), ),
"four_way_already_ordered": ReorderTestCase(
requests=[
(1, 100, 100), # decode
(1, 50, 100), # short_extend
(10, 50, 100), # long_extend
(100, 0, 100), # prefill
],
expected_order=[0, 1, 2, 3],
expected_modified=False,
),
"four_way_needs_reorder": ReorderTestCase(
requests=[
(100, 0, 100), # prefill
(1, 50, 100), # short_extend
(1, 100, 100), # decode
(10, 50, 100), # long_extend
],
expected_order=[2, 1, 3, 0],
expected_modified=True,
),
"four_way_multiple_short_extends": ReorderTestCase(
requests=[
(2, 100, 100), # decode
(2, 50, 200), # short_extend
(2, 75, 150), # short_extend
(2, 200, 200), # decode
],
expected_order=[0, 3, 2, 1],
expected_modified=True,
decode_threshold=2,
),
"four_way_spec_decode_threshold": ReorderTestCase(
requests=[
(5, 100, 100), # decode
(5, 50, 100), # short_extend
(5, 0, 100), # prefill
(10, 50, 100), # long_extend
],
expected_order=[0, 1, 3, 2],
expected_modified=True,
decode_threshold=5,
),
} }
...@@ -129,8 +177,9 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase) ...@@ -129,8 +177,9 @@ def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase)
req_ids = [f"r{i}" for i in range(len(test_case.requests))] req_ids = [f"r{i}" for i in range(len(test_case.requests))]
num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32) num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)} num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}
num_prompt_tokens = np.array([r[2] for r in test_case.requests], dtype=np.int32)
input_batch = MockInputBatch(req_ids, num_computed_tokens) input_batch = MockInputBatch(req_ids, num_computed_tokens, num_prompt_tokens)
scheduler_output = MockSchedulerOutput(num_scheduled_tokens) scheduler_output = MockSchedulerOutput(num_scheduled_tokens)
modified = reorder_batch_to_split_decodes_and_prefills( modified = reorder_batch_to_split_decodes_and_prefills(
......
...@@ -43,7 +43,7 @@ MESSAGES = [ ...@@ -43,7 +43,7 @@ MESSAGES = [
pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]), pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]),
pytest.param( pytest.param(
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8", "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2), marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=4),
), ),
], ],
) )
...@@ -68,7 +68,7 @@ def test_mtp_speculative_mixed_batch_short_prefill( ...@@ -68,7 +68,7 @@ def test_mtp_speculative_mixed_batch_short_prefill(
max_num_batched_tokens=chunk_size, max_num_batched_tokens=chunk_size,
max_model_len=512, max_model_len=512,
enforce_eager=True, enforce_eager=True,
tensor_parallel_size=2, tensor_parallel_size=4,
trust_remote_code=True, trust_remote_code=True,
enable_chunked_prefill=True, enable_chunked_prefill=True,
enable_prefix_caching=enable_prefix_caching, enable_prefix_caching=enable_prefix_caching,
......
...@@ -362,6 +362,11 @@ class CommonAttentionMetadata: ...@@ -362,6 +362,11 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world""" """Sequence lengths of the local rank in decode context parallelism world"""
is_prefilling: torch.Tensor | None = None
"""(batch_size,) bool tensor: True if request is still in prefill phase
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None _seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None _num_computed_tokens_cpu: torch.Tensor | None = None
...@@ -443,6 +448,7 @@ class CommonAttentionMetadata: ...@@ -443,6 +448,7 @@ class CommonAttentionMetadata:
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
is_prefilling=maybe_slice_reqs(self.is_prefilling),
) )
......
...@@ -358,7 +358,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -358,7 +358,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, decode_threshold=decode_threshold common_attn_metadata,
decode_threshold=decode_threshold,
treat_short_extends_as_decodes=False,
) )
) )
......
...@@ -489,11 +489,15 @@ def split_decodes_and_prefills( ...@@ -489,11 +489,15 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1, decode_threshold: int = 1,
require_uniform: bool = False, require_uniform: bool = False,
treat_short_extends_as_decodes: bool = True,
) -> tuple[int, int, int, int]: ) -> tuple[int, int, int, int]:
""" """
Assuming a reordered batch, finds the boundary between prefill and decode Assuming a reordered batch, finds the boundary between prefill and decode
requests. requests.
The batch is expected to be ordered as:
decode → short_extend → long_extend → prefill
Args: Args:
common_attn_metadata: CommonAttentionMetadata object containing the common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata. batch metadata.
...@@ -501,6 +505,9 @@ def split_decodes_and_prefills( ...@@ -501,6 +505,9 @@ def split_decodes_and_prefills(
require_uniform: If True, requires that all decode requests have the require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity. even if they are <= decode_threshold, in order to ensure uniformity.
treat_short_extends_as_decodes: If True (default), short extends
(query_len <= threshold but still prefilling) are counted as
decodes. If False, they are counted as prefills.
Returns: Returns:
num_decodes: The number of decode requests. num_decodes: The number of decode requests.
...@@ -513,8 +520,10 @@ def split_decodes_and_prefills( ...@@ -513,8 +520,10 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold and ( if (
not require_uniform or decode_threshold <= 1 max_query_len <= decode_threshold
and (not require_uniform or decode_threshold <= 1)
and treat_short_extends_as_decodes
): ):
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
...@@ -533,11 +542,14 @@ def split_decodes_and_prefills( ...@@ -533,11 +542,14 @@ def split_decodes_and_prefills(
else: else:
is_prefill = query_lens > decode_threshold is_prefill = query_lens > decode_threshold
if not treat_short_extends_as_decodes:
assert common_attn_metadata.is_prefilling is not None
is_prefill |= common_attn_metadata.is_prefilling
if not torch.any(is_prefill): if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item() first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill num_decodes = first_prefill
num_prefills = num_reqs - num_decodes num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item() num_decode_tokens = query_start_loc[first_prefill].item()
...@@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills(
Reorders the batch to split into prefill and decode requests; places all Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch. requests with <= decode_threshold tokens at the front of the batch.
The batch is reordered into 4 regions:
decode: (num_scheduled <= threshold AND is not prefilling)
short_extend: (num_scheduled <= threshold AND is chunked prefilling)
long_extend: (num_scheduled > threshold AND is chunked prefilling)
prefill: (num_computed == 0) # First chunks
Returns: Returns:
True if the batch was modified, False otherwise. True if the batch was modified, False otherwise.
""" """
# We now want to reorder the batch into decode → extend → prefill order
# where:
# decode: request with num_scheduled_tokens <= decode_threshold
# extend: non-decode request with existing context
# prefill: non-decode request with no existing context
# NOTE for now we loosely use "decode" to mean requests where attention is
# likely memory-bound and "prefill" to mean requests where attention is
# likely compute-bound,
num_reqs = len(input_batch.req_ids) num_reqs = len(input_batch.req_ids)
num_scheduled_tokens = [ num_scheduled_tokens = [
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
] ]
num_scheduled_tokens_np = np.array(num_scheduled_tokens) num_scheduled_tokens_np = np.array(num_scheduled_tokens)
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
num_prompt_tokens_np = input_batch.num_prompt_tokens[:num_reqs]
is_prefill = num_computed_tokens_np == 0
is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill) has_context = num_computed_tokens_np > 0
is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill) is_below_threshold = num_scheduled_tokens_np <= decode_threshold
done_prefilling = num_computed_tokens_np >= num_prompt_tokens_np
# Desired order: decode → extend → prefill
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default # Mutually exclusive categories (exactly one True per request):
req_regions[is_extend] = 1 # 1. No context yet -> prefill
req_regions[is_prefill] = 2 # 2. Has context, above threshold -> long_extend
# 3. Has context, below threshold, still prefilling -> short_extend
# 4. Has context, below threshold, done prefilling -> decode
is_pure_prefill = ~has_context
is_long_extend = has_context & ~is_below_threshold
is_short_extend = has_context & is_below_threshold & ~done_prefilling
is_decode = has_context & is_below_threshold & done_prefilling
# Desired order: decode → short_extend → long_extend → prefill
req_regions = np.zeros(num_reqs, dtype=np.int32) # 0 = decode by default
req_regions[is_short_extend] = 1
req_regions[is_long_extend] = 2
req_regions[is_pure_prefill] = 3
num_decodes = int(is_decode.sum()) num_decodes = int(is_decode.sum())
num_extends = int(is_extend.sum()) num_short_extends = int(is_short_extend.sum())
num_long_extends = int(is_long_extend.sum())
num_prefills = int(is_pure_prefill.sum())
target_regions = np.zeros(num_reqs, dtype=np.int32) target_regions = np.repeat(
target_regions[num_decodes : num_decodes + num_extends] = 1 [0, 1, 2, 3],
target_regions[num_decodes + num_extends :] = 2 [num_decodes, num_short_extends, num_long_extends, num_prefills],
).astype(np.int32)
needs_swap = req_regions != target_regions needs_swap = req_regions != target_regions
......
...@@ -134,7 +134,13 @@ class InputBatch: ...@@ -134,7 +134,13 @@ class InputBatch:
pin_memory=pin_memory, pin_memory=pin_memory,
) )
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy() self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()
self.num_computed_tokens_cpu_tensor = torch.zeros( self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,), (max_num_reqs,),
device="cpu", device="cpu",
......
...@@ -740,19 +740,6 @@ class GPUModelRunner( ...@@ -740,19 +740,6 @@ class GPUModelRunner(
self.uniform_decode_query_len = 1 + self.num_spec_tokens self.uniform_decode_query_len = 1 + self.num_spec_tokens
# When spec decode is active, the mamba backend classifies requests
# with query_len <= reorder_batch_threshold as "decodes". Prefill
# chunks that fall under this threshold get processed via the decode
# path, which stores intermediate states at sequential slots. We must
# set num_accepted_tokens to the chunk's query_len for those requests
# so the next iteration reads from the correct final-state slot.
# Prefills that went through the actual prefill path should keep the
# default value of 1 (the prefill path stores state at slot 0 only).
self.needs_prefill_as_decode_slots: bool = False
self.prefill_as_decode_num_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Cudagraph dispatcher for runtime cudagraph dispatching. # Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
...@@ -1369,16 +1356,6 @@ class GPUModelRunner( ...@@ -1369,16 +1356,6 @@ class GPUModelRunner(
.int() .int()
.argmax(-1) .argmax(-1)
) )
spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens)
if self.needs_prefill_as_decode_slots and spec_decode_active:
mamba_utils.update_accepted_tokens_for_prefill_as_decode(
self.input_batch,
self.prefill_as_decode_num_tokens,
self.num_accepted_tokens.gpu,
scheduler_output,
self.reorder_batch_threshold,
num_reqs,
)
if self.cache_config.mamba_cache_mode == "align": if self.cache_config.mamba_cache_mode == "align":
for i, num_tokens in enumerate( for i, num_tokens in enumerate(
...@@ -1982,14 +1959,23 @@ class GPUModelRunner( ...@@ -1982,14 +1959,23 @@ class GPUModelRunner(
attn_gid = self.routed_experts_attn_gid attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid] slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
# distinguish actual decodes from short extends.
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[
:num_reqs_padded
]
is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded], seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ _num_computed_tokens_cpu=num_computed_tokens_cpu,
:num_reqs_padded
],
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded, num_actual_tokens=num_tokens_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
...@@ -1997,6 +1983,7 @@ class GPUModelRunner( ...@@ -1997,6 +1983,7 @@ class GPUModelRunner(
block_table_tensor=block_table_gid_0, block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0, slot_mapping=slot_mapping_gid_0,
causal=True, causal=True,
is_prefilling=is_prefilling,
) )
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
...@@ -2048,8 +2035,6 @@ class GPUModelRunner( ...@@ -2048,8 +2035,6 @@ class GPUModelRunner(
else 0 else 0
) )
if isinstance(builder, Mamba2AttentionMetadataBuilder):
self.needs_prefill_as_decode_slots = True
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance( if use_spec_decode and isinstance(
builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
......
...@@ -266,45 +266,3 @@ def postprocess_mamba( ...@@ -266,45 +266,3 @@ def postprocess_mamba(
if src_block_idx == dest_block_idx: if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1 num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(copy_bufs) do_mamba_copy_block(copy_bufs)
def update_accepted_tokens_for_prefill_as_decode(
input_batch: GPUInputBatch,
prefill_as_decode_num_tokens: CpuGpuBuffer,
num_accepted_tokens_gpu: torch.Tensor,
scheduler_output: SchedulerOutput,
decode_qlen_threshold: int | None,
num_reqs: int,
):
"""
Adjusts num_accepted_tokens for prefill chunks processed via the decode path.
This ensures subsequent iterations read from the correct sequential state slot
instead of the default prefill slot 0. Not used by GDN attention, which manually
separates short prefills and short decodes when building the attention metadata.
"""
any_is_prefill = False
for i in range(num_reqs):
num_computed = input_batch.num_computed_tokens_cpu[i]
num_prompt = input_batch.num_prompt_tokens[i]
is_prefill = num_computed < num_prompt
req_id = input_batch.req_ids[i]
query_len = scheduler_output.num_scheduled_tokens[req_id]
if is_prefill:
classified_as_decode = (
decode_qlen_threshold is not None and query_len <= decode_qlen_threshold
)
num_tokens = query_len if classified_as_decode else 1
any_is_prefill = True
else:
num_tokens = -1
prefill_as_decode_num_tokens.np[i] = num_tokens
# We can skip the GPU transfer if there aren't any values to update
if any_is_prefill:
prefill_as_decode_num_tokens.copy_to_gpu(num_reqs)
num_accepted_tokens_gpu[:num_reqs] = torch.where(
prefill_as_decode_num_tokens.gpu[:num_reqs] != -1,
prefill_as_decode_num_tokens.gpu[:num_reqs],
num_accepted_tokens_gpu[:num_reqs],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment