Unverified Commit 5de6dd06 authored by Kebe's avatar Kebe Committed by GitHub
Browse files

[Bugfix] [DeepSeek-V3.2] fix sparse_attn_indexer padding (#32175)


Signed-off-by: default avatarKebe <mail@kebe7jun.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 70950255
...@@ -717,13 +717,20 @@ def sparse_attn_indexer( ...@@ -717,13 +717,20 @@ def sparse_attn_indexer(
# decode_threshold since we unstrictly split # decode_threshold since we unstrictly split
# prefill and decode by decode_threshold # prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens) # (currently set to 1 + speculative tokens)
# [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
padded_q_fp8_decode_tokens = pack_seq_triton( padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens q_fp8[:num_decode_tokens], decode_lens
) )
# [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
# [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
padded_weights = padded_weights.flatten(0, 1)
else: else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:] decode_lens.shape[0], -1, *q_fp8.shape[1:]
) )
padded_weights = weights
# TODO: move and optimize below logic with triton kernels # TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0] batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1] next_n = padded_q_fp8_decode_tokens.shape[1]
...@@ -739,14 +746,14 @@ def sparse_attn_indexer( ...@@ -739,14 +746,14 @@ def sparse_attn_indexer(
logits = fp8_paged_mqa_logits_func( logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens, padded_q_fp8_decode_tokens,
kv_cache, kv_cache,
weights[:num_padded_tokens], padded_weights[:num_padded_tokens],
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_model_len=max_model_len, max_model_len=max_model_len,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode( torch.ops._C.top_k_per_row_decode(
logits, logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment