Unverified Commit 9040cd40 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[DSV3.2][MTP] Optimize Indexer MTP handling (#36723)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent fa0d353a
......@@ -384,12 +384,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
seq_lens - decode_lens, decode_lens, output_size=actual_expanded
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
common_attn_metadata.query_start_loc[:num_decodes],
decode_lens,
output_size=actual_expanded,
)
# [0, 1, 2, 0, 0, 1, 2, 3]
......@@ -407,7 +409,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
torch.repeat_interleave(
block_table, decode_lens, dim=0, output_size=actual_expanded
)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment