Unverified Commit de717476 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[SpecDecode] Simplified alternative padded-speculation acceptance rate fix (#29845)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 95863540
...@@ -306,9 +306,15 @@ def test_prepare_inputs_padded(): ...@@ -306,9 +306,15 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded( output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
) )
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3 assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
......
...@@ -564,6 +564,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -564,6 +564,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size
# Don't try to access the runner on AMD # Don't try to access the runner on AMD
if self.aot_schedule: if self.aot_schedule:
...@@ -727,8 +728,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -727,8 +728,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
...@@ -778,13 +779,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -778,13 +779,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills( split_decodes_and_prefills(
...@@ -799,6 +794,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -799,6 +794,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata = None prefill_metadata = None
if num_prefills > 0: if num_prefills > 0:
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
...@@ -995,13 +992,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -995,13 +992,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
dcp_tot_seq_lens_device = None dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes] dcp_tot_seq_lens_device = seq_lens[:num_decodes]
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
# This eliminates GPU->CPU sync while minimizing workspace
# over-allocation.
num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...], block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes], seq_lens_device=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1], query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
......
...@@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -169,8 +169,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
...@@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -178,7 +178,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata: ) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item() max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph # For Flash Attention MLA + full cudagraph
max_num_splits = 0 max_num_splits = 0
...@@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -193,7 +192,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits = 1 max_num_splits = 1
scheduler_metadata = self._schedule_decode( scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(), num_reqs=seq_lens_device.shape[0],
cu_query_lens=query_start_loc_device, cu_query_lens=query_start_loc_device,
max_query_len=max_query_len, max_query_len=max_query_len,
seqlens=seq_lens_device, seqlens=seq_lens_device,
......
...@@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -143,8 +143,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
......
...@@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -106,8 +106,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def _build_decode( def _build_decode(
self, self,
block_table_tensor: torch.Tensor, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, seq_lens_device: torch.Tensor,
max_seq_len: int,
query_start_loc_cpu: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor, query_start_loc_device: torch.Tensor,
num_decode_tokens: int, num_decode_tokens: int,
......
...@@ -236,6 +236,7 @@ class EagleProposer: ...@@ -236,6 +236,7 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -414,6 +415,17 @@ class EagleProposer: ...@@ -414,6 +415,17 @@ class EagleProposer:
common_attn_metadata.query_start_loc_cpu = torch.from_numpy( common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1] self.token_arange_np[: batch_size + 1]
).clone() ).clone()
# In padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
# Invalidate the CPU-side shadows to avoid H<>D sync.
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1): for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
...@@ -628,13 +640,14 @@ class EagleProposer: ...@@ -628,13 +640,14 @@ class EagleProposer:
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata, spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor, valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor]: ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
""" """
This function is used to prepare the inputs for speculative decoding This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding, It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`. used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
""" """
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
device = valid_sampled_tokens_count.device device = valid_sampled_tokens_count.device
...@@ -642,14 +655,17 @@ class EagleProposer: ...@@ -642,14 +655,17 @@ class EagleProposer:
token_indices_to_sample = torch.empty( token_indices_to_sample = torch.empty(
(num_reqs,), dtype=torch.int32, device=device (num_reqs,), dtype=torch.int32, device=device
) )
num_rejected_tokens_gpu = torch.empty(
(num_reqs,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row)
grid = (num_reqs,) grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid]( eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens, spec_decode_metadata.cu_num_draft_tokens,
valid_sampled_tokens_count, valid_sampled_tokens_count,
common_attn_metadata.query_start_loc, common_attn_metadata.query_start_loc,
token_indices_to_sample, token_indices_to_sample,
num_rejected_tokens_gpu,
num_reqs, num_reqs,
) )
...@@ -674,7 +690,11 @@ class EagleProposer: ...@@ -674,7 +690,11 @@ class EagleProposer:
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
) )
return spec_common_attn_metadata, token_indices_to_sample return (
spec_common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
)
def propose_tree( def propose_tree(
self, self,
......
...@@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel( ...@@ -23,6 +23,7 @@ def eagle_prepare_inputs_padded_kernel(
valid_sampled_tokens_count_ptr, # [num_reqs] valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1] query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output) token_indices_to_sample_ptr, # [num_reqs] (output)
num_rejected_tokens_gpu_ptr, # [num_reqs] (output)
num_reqs, # tl.int32 num_reqs, # tl.int32
): ):
""" """
...@@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel( ...@@ -56,6 +57,7 @@ def eagle_prepare_inputs_padded_kernel(
index_to_sample = q_last_tok_idx - num_rejected_tokens index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample) tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit @triton.jit
......
...@@ -3534,6 +3534,7 @@ class GPUModelRunner( ...@@ -3534,6 +3534,7 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
num_rejected_tokens_gpu = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
token_indices_to_sample = None token_indices_to_sample = None
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
...@@ -3564,13 +3565,15 @@ class GPUModelRunner( ...@@ -3564,13 +3565,15 @@ class GPUModelRunner(
else: else:
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
else: else:
common_attn_metadata, token_indices_to_sample = ( (
self.drafter.prepare_inputs_padded( common_attn_metadata,
token_indices_to_sample,
num_rejected_tokens_gpu,
) = self.drafter.prepare_inputs_padded(
common_attn_metadata, common_attn_metadata,
spec_decode_metadata, spec_decode_metadata,
valid_sampled_tokens_count, valid_sampled_tokens_count,
) )
)
total_num_tokens = common_attn_metadata.num_actual_tokens total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range # When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens] target_token_ids = self.input_ids.gpu[:total_num_tokens]
...@@ -3600,6 +3603,7 @@ class GPUModelRunner( ...@@ -3600,6 +3603,7 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
) )
return draft_token_ids return draft_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment