Commit d89f7579 authored by zhuwenwen
Browse files

update paged_mqa_logits

parent 30399801
......@@ -691,6 +691,8 @@ def sparse_attn_indexer(
max_model_len=max_model_len,
)
else:
+ padded_q_fp8_decode_tokens = padded_q_fp8_decode_tokens.half
+ kv_cache = kv_cache.half
logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
......@@ -698,7 +700,7 @@ def sparse_attn_indexer(
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
- max_model_len=max_model_len,
+ max_context_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
......
......@@ -312,7 +312,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if current_platform.is_rocm():
- self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata(
+ self.scheduler_metadata_buffer = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment