Unverified Commit de717476 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[SpecDecode] Simplified alternative padded-speculation acceptance rate fix (#29845)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 95863540
......@@ -306,9 +306,15 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
......
......@@ -564,6 +564,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
# Don't try to access the runner on AMD
if self.aot_schedule:
......@@ -727,8 +728,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
......@@ -778,13 +779,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
......@@ -799,6 +794,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None
if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
......@@ -995,13 +992,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
# This eliminates GPU->CPU sync while minimizing workspace
# over-allocation.
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
......
......@@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
......@@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
......@@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits = 1
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,
......
......@@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
......
......@@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
......
......@@ -236,6 +236,7 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -414,6 +415,17 @@ class EagleProposer:
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1]
).clone()
# In padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
# Invalidate the CPU-side shadows to avoid H<>D sync.
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
......@@ -628,13 +640,14 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device
......@@ -642,14 +655,17 @@ class EagleProposer:
token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
num_rejected_tokens_gpu = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row)
grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count,
common_attn_metadata.query_start_loc,
token_indices_to_sample,
num_rejected_tokens_gpu,
num_reqs,
)
......@@ -674,7 +690,11 @@ class EagleProposer:
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
return spec_common_attn_metadata, token_indices_to_sample
return (
spec_common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
)
def propose_tree(
self,
......
......@@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output)
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
num_reqs, # tl.int32
):
"""
......@@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit
......
......@@ -3534,6 +3534,7 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
num_rejected_tokens_gpu = None
if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
......@@ -3564,13 +3565,15 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[token_indices]
else:
common_attn_metadata, token_indices_to_sample = (
self.drafter.prepare_inputs_padded(
(
common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
) = self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count,
)
)
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens]
......@@ -3600,6 +3603,7 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
return draft_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment