Unverified Commit fe85a92e authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Core] Avoid seq_lens_cpu GPU->CPU sync (#40654)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 62b1bbe4
......@@ -107,6 +107,7 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
......
......@@ -241,11 +241,13 @@ def forward_attention(
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
seq_lens_cpu = seq_lens.cpu()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
_seq_lens_cpu=seq_lens.cpu(),
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
......
......@@ -90,15 +90,23 @@ def create_cross_attention_backend(
assert new_metadata.encoder_seq_lens_cpu is not None
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
# Any computed tokens indicates decode step>1 (no chunked prefill).
# The upper bound is exact for this `> 0` test - prefill rows have
# num_computed == 0 and decode rows have num_computed > 0.
query_lens_cpu = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu_upper_bound - query_lens_cpu
)
num_cache_decodes = (num_computed_tokens_cpu > 0).sum().item()
if num_cache_decodes > 0:
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
num_tokens = num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
)
......
......@@ -1822,13 +1822,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Upper bound is exact for prefill rows (no D2H sync).
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
prefill_query_lens_cpu = (
query_start_loc_cpu[reqs_start + 1 : num_reqs + 1]
- query_start_loc_cpu[reqs_start:num_reqs]
)
context_lens_cpu = (
seq_lens_cpu[reqs_start:num_reqs] - prefill_query_lens_cpu
)
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = (
......
......@@ -397,6 +397,12 @@ class CommonAttentionMetadata:
(num_computed_tokens < num_prompt_tokens). Used by some backends to
distinguish actual decodes from short extends."""
seq_lens_cpu_upper_bound: torch.Tensor | None = None
"""(batch_size,) CPU upper bound on seq_lens. Precise for prefill rows
and for all rows outside async spec decode; optimistic for async-spec
decode rows (assumes every draft was accepted). Not safe for kernels
that need exact per-row context lengths on decode rows."""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
......
......@@ -782,10 +782,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> FlexAttentionMetadata:
# Use actual max_seq_len instead of max_model_len to avoid
# torch.compile recompilation during CUDA graph capture.
common_attn_metadata.max_seq_len = (
common_attn_metadata.seq_lens_cpu.max().item()
# Use actual max_seq_len (not max_model_len) to avoid torch.compile
# recompilation during CUDA graph capture.
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
common_attn_metadata.max_seq_len = int(
common_attn_metadata.seq_lens_cpu_upper_bound.max().item()
)
return self.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
......
......@@ -364,7 +364,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below), so no D2H sync is needed.
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
assert seq_lens_cpu is not None
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......
......@@ -554,8 +554,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below).
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
chunk_specs = split_indexer_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
......@@ -566,7 +570,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
req_slice,
query_slice,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
seq_lens_cpu,
common_attn_metadata.block_table_tensor,
skip_kv_gather=query_slice.start > 0,
)
......
......@@ -356,6 +356,7 @@ def make_local_attention_virtual_batches(
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
seq_lens_cpu_upper_bound=seq_lens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
), make_block_table
......@@ -414,6 +415,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
......@@ -445,7 +447,11 @@ def split_decodes_prefills_and_extends(
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
# Upper bound is exact for prefill rows; decode rows still satisfy
# seq_len > query_len under the optimistic bound, so `seq_lens ==
# query_lens` identifies prefills correctly either way.
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
seq_lens = common_attn_metadata.seq_lens_cpu_upper_bound
if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0
......
......@@ -151,6 +151,12 @@ class DFlashProposer(SpecDecodeBaseProposer):
if has_num_rejected:
effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu
# Skip num_rejected_tokens (GPU-only); overestimating is fine here.
new_seq_lens_cpu_upper_bound = (
cad.seq_lens_cpu_upper_bound + num_query_per_req
if cad.seq_lens_cpu_upper_bound is not None
else None
)
new_cad = CommonAttentionMetadata(
query_start_loc=new_query_start_loc,
seq_lens=effective_seq_lens + num_query_per_req,
......@@ -160,6 +166,7 @@ class DFlashProposer(SpecDecodeBaseProposer):
),
_seq_lens_cpu=None,
_num_computed_tokens_cpu=None,
seq_lens_cpu_upper_bound=new_seq_lens_cpu_upper_bound,
num_reqs=cad.num_reqs,
num_actual_tokens=num_query_total,
max_query_len=num_query_per_req,
......
......@@ -593,6 +593,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu += 1
if common_attn_metadata._num_computed_tokens_cpu is not None:
common_attn_metadata._num_computed_tokens_cpu += 1
if common_attn_metadata.seq_lens_cpu_upper_bound is not None:
common_attn_metadata.seq_lens_cpu_upper_bound += 1
# Rebuild attention metadata
_, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
......@@ -959,6 +961,7 @@ class SpecDecodeBaseProposer:
query_start_loc_cpu=query_start_loc_cpu,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
......@@ -1183,7 +1186,11 @@ class SpecDecodeBaseProposer:
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
# upper_bound - rejected = actual post-rejection seq_lens (no D2H sync).
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
new_seq_lens_cpu = (
common_attn_metadata.seq_lens_cpu_upper_bound - num_rejected_tokens
)
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
......@@ -1237,6 +1244,7 @@ class SpecDecodeBaseProposer:
query_start_loc_cpu=new_query_start_loc_cpu,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
seq_lens_cpu_upper_bound=new_seq_lens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
......
......@@ -227,12 +227,15 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
seq_lens_cpu_upper_bound: torch.Tensor | None = None,
dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
if seq_lens_cpu_upper_bound is not None:
seq_lens_cpu_upper_bound = seq_lens_cpu_upper_bound[:num_reqs]
attn_metadata: dict[str, Any] = {}
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
......@@ -244,6 +247,7 @@ def build_attn_metadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
max_seq_len=max_seq_len,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
......
......@@ -60,6 +60,8 @@ class InputBatch:
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_reqs] CPU upper bound on seq_lens (see CommonAttentionMetadata).
seq_lens_cpu_upper_bound: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None
......@@ -121,6 +123,8 @@ class InputBatch:
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
# Dummy: seq_len == query_len (fresh-prefill shape).
seq_lens_cpu_upper_bound = torch.from_numpy(num_scheduled_tokens.copy())
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
......@@ -136,6 +140,7 @@ class InputBatch:
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
......
......@@ -799,6 +799,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits,
)
# CPU upper bound on seq_lens; padded entries left at zero.
seq_lens_cpu_upper_bound_np = np.zeros(num_reqs_padded, dtype=np.int32)
np.add(
self.req_states.num_computed_tokens_np[idx_mapping_np],
num_scheduled_tokens,
out=seq_lens_cpu_upper_bound_np[:num_reqs],
)
seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
......@@ -814,6 +823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
......@@ -927,6 +937,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
@torch.inference_mode()
def execute_model(
......@@ -1297,6 +1311,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
########### EPLB methods start ###########
@property
......
......@@ -173,6 +173,12 @@ class DefaultModelState(ModelState):
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
# Capture with worst-case max_seq_len so the graph is valid at any replay.
max_seq_len = self.max_model_len
else:
max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
......@@ -181,10 +187,11 @@ class DefaultModelState(ModelState):
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
)
return attn_metadata
......@@ -117,6 +117,11 @@ class WhisperModelState(ModelState):
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
max_seq_len = self.max_model_len
else:
max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
......@@ -125,10 +130,11 @@ class WhisperModelState(ModelState):
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
encoder_seq_lens=encoder_seq_lens,
)
......
......@@ -57,6 +57,8 @@ class RequestState:
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Optimistic CPU mirror of num_computed_tokens (upper bound on GPU value).
self.num_computed_tokens_np = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
......@@ -100,6 +102,7 @@ class RequestState:
self.total_len.stage_write_elem(req_idx, prefill_len)
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens_np[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
if num_computed_tokens > 0 and num_computed_tokens <= prefill_len:
......
......@@ -2155,6 +2155,7 @@ class GPUModelRunner(
:num_reqs_padded
]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded]
seq_lens_cpu_upper_bound = seq_lens_cpu
# is_prefilling: True if request is still in prefill phase.
# Used by mamba backends to distinguish actual decodes from
......@@ -2172,6 +2173,7 @@ class GPUModelRunner(
seq_lens=self.seq_lens[:num_reqs_padded],
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
max_query_len=max_query_len,
......
......@@ -177,7 +177,22 @@ def _make_metadata_with_slice(
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
# Read raw fields to avoid triggering the deprecated D2H-syncing properties.
seq_lens_cpu = (
attn_metadata._seq_lens_cpu[request_slice]
if attn_metadata._seq_lens_cpu is not None
else None
)
seq_lens_cpu_upper_bound = (
attn_metadata.seq_lens_cpu_upper_bound[request_slice]
if attn_metadata.seq_lens_cpu_upper_bound is not None
else None
)
num_computed_tokens_cpu = (
attn_metadata._num_computed_tokens_cpu[request_slice]
if attn_metadata._num_computed_tokens_cpu is not None
else None
)
if splits_last_request:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
......@@ -190,12 +205,16 @@ def _make_metadata_with_slice(
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
if seq_lens_cpu is not None:
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens_cpu[-1] -= tokens_skipped
if seq_lens_cpu_upper_bound is not None:
seq_lens_cpu_upper_bound = seq_lens_cpu_upper_bound.clone()
seq_lens_cpu_upper_bound[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
assert seq_lens_cpu_upper_bound is not None
max_seq_len = int(seq_lens_cpu_upper_bound.max())
num_requests = request_slice.stop - request_slice.start
num_actual_tokens = token_slice.stop - token_slice.start
......@@ -221,6 +240,7 @@ def _make_metadata_with_slice(
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment