"lib/runtime/src/vscode:/vscode.git/clone" did not exist on "b0959cfdbce3dd11273449e8c4a89d3506640f09"
Unverified Commit 9f39b380 authored by Rishi Puri's avatar Rishi Puri Committed by GitHub
Browse files

[Bugfix] Fix spec decode test failures on Blackwell (SM100+) (#39546)


Signed-off-by: default avatarStefano Castagnetta <scastagnetta@nvidia.com>
Signed-off-by: default avatarRishi Puri <puririshi98@berkeley.edu>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: default avatarStefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 9a6a66f3
...@@ -12,6 +12,17 @@ steps: ...@@ -12,6 +12,17 @@ steps:
commands: commands:
- pytest -v -s v1/e2e/spec_decode -k "eagle_correctness" - pytest -v -s v1/e2e/spec_decode -k "eagle_correctness"
- label: Spec Decode Eagle Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "eagle_correctness"
- label: Spec Decode Speculators + MTP - label: Spec Decode Speculators + MTP
timeout_in_minutes: 30 timeout_in_minutes: 30
device: h200_18gb device: h200_18gb
...@@ -23,6 +34,18 @@ steps: ...@@ -23,6 +34,18 @@ steps:
commands: commands:
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness" - pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
- label: Spec Decode Speculators + MTP Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- vllm/transformers_utils/configs/speculators/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
- label: Spec Decode Ngram + Suffix - label: Spec Decode Ngram + Suffix
timeout_in_minutes: 30 timeout_in_minutes: 30
device: h200_18gb device: h200_18gb
...@@ -43,6 +66,17 @@ steps: ...@@ -43,6 +66,17 @@ steps:
commands: commands:
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference" - pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
- label: Spec Decode Draft Model Nightly B200
timeout_in_minutes: 30
device: b200
optional: true
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- tests/v1/e2e/spec_decode/
commands:
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
- label: DFlash Speculators Correctness - label: DFlash Speculators Correctness
timeout_in_minutes: 30 timeout_in_minutes: 30
device: h100 device: h100
......
...@@ -920,9 +920,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -920,9 +920,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and ( all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
num_decodes == 0 or decode_use_trtllm num_decodes == 0 or decode_use_trtllm
) )
is_only_trtllm_decode = num_prefills == 0 and (
num_decodes > 0 and decode_use_trtllm
)
if not all_uses_trtllm: if not all_uses_trtllm:
if self.has_sinks: if self.has_sinks:
...@@ -968,7 +965,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -968,7 +965,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Guard access to seq_lens_cpu, which may not always be needed # Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode. # and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode # When all attention (both prefill and decode) uses TRTLLM,
# seq_lens_cpu is not needed since TRTLLM paths use GPU tensors
# (block_tables, seq_lens) directly.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not all_uses_trtllm
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = ( num_blocks_np = (
...@@ -1006,7 +1006,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -1006,7 +1006,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_blocks_np -= num_common_kv_blocks num_blocks_np -= num_common_kv_blocks
# Compute paged_kv_indices if necessary # Compute paged_kv_indices if necessary
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode # paged_kv_indices is only needed for FlashInfer native paths;
# TRTLLM paths use block_tables directly on GPU.
needs_paged_kv_indices = use_cascade or not all_uses_trtllm
if needs_paged_kv_indices: if needs_paged_kv_indices:
assert num_blocks_np is not None assert num_blocks_np is not None
assert seq_lens_np is not None assert seq_lens_np is not None
...@@ -1083,9 +1085,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -1083,9 +1085,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
qo_indptr_prefill_gpu = ( qo_indptr_prefill_gpu = (
qo_indptr[prefill_start:] - qo_indptr[prefill_start] qo_indptr[prefill_start:] - qo_indptr[prefill_start]
) )
# Compute cum_seq_lens_kv on GPU to avoid CPU sync.
# This is the cumulative sum of the number of KV cache
# blocks per prefill request.
prefill_seq_lens = seq_lens[prefill_start:]
num_blocks_per_req = (prefill_seq_lens + page_size - 1) // page_size
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[ paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
prefill_start : num_reqs + 1 prefill_start : num_reqs + 1
] ]
paged_kv_indptr_prefill_gpu[0] = 0
torch.cumsum(
num_blocks_per_req,
dim=0,
out=paged_kv_indptr_prefill_gpu[1:],
)
# Compute max_q_len for prefill requests # Compute max_q_len for prefill requests
query_lens_prefill_cpu = ( query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment