Unverified commit 9f39b380, authored by Rishi Puri and committed by GitHub
Browse files

[Bugfix] Fix spec decode test failures on Blackwell (SM100+) (#39546)


Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Signed-off-by: Rishi Puri <puririshi98@berkeley.edu>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
parent 9a6a66f3
......@@ -12,6 +12,17 @@ steps:
commands:
- pytest -v -s v1/e2e/spec_decode -k "eagle_correctness"
- label: Spec Decode Eagle Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "eagle_correctness"
- label: Spec Decode Speculators + MTP
timeout_in_minutes: 30
device: h200_18gb
......@@ -23,6 +34,18 @@ steps:
commands:
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
- label: Spec Decode Speculators + MTP Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- vllm/transformers_utils/configs/speculators/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
- label: Spec Decode Ngram + Suffix
timeout_in_minutes: 30
device: h200_18gb
......@@ -43,6 +66,17 @@ steps:
commands:
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
- label: Spec Decode Draft Model Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
- label: DFlash Speculators Correctness
timeout_in_minutes: 30
device: h100
......
......@@ -920,9 +920,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
num_decodes == 0 or decode_use_trtllm
)
is_only_trtllm_decode = num_prefills == 0 and (
num_decodes > 0 and decode_use_trtllm
)
if not all_uses_trtllm:
if self.has_sinks:
......@@ -968,7 +965,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
# When all attention (both prefill and decode) uses TRTLLM,
# seq_lens_cpu is not needed since TRTLLM paths use GPU tensors
# (block_tables, seq_lens) directly.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not all_uses_trtllm
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = (
......@@ -1006,7 +1006,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_blocks_np -= num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
# paged_kv_indices is only needed for FlashInfer native paths;
# TRTLLM paths use block_tables directly on GPU.
needs_paged_kv_indices = use_cascade or not all_uses_trtllm
if needs_paged_kv_indices:
assert num_blocks_np is not None
assert seq_lens_np is not None
......@@ -1083,9 +1085,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
qo_indptr_prefill_gpu = (
qo_indptr[prefill_start:] - qo_indptr[prefill_start]
)
# Compute cum_seq_lens_kv on GPU to avoid CPU sync.
# This is the cumulative sum of the number of KV cache
# blocks per prefill request.
prefill_seq_lens = seq_lens[prefill_start:]
num_blocks_per_req = (prefill_seq_lens + page_size - 1) // page_size
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
prefill_start : num_reqs + 1
]
paged_kv_indptr_prefill_gpu[0] = 0
torch.cumsum(
num_blocks_per_req,
dim=0,
out=paged_kv_indptr_prefill_gpu[1:],
)
# Compute max_q_len for prefill requests
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment