Commit d89f7579 authored by zhuwenwen's avatar zhuwenwen
Browse files

update paged_mqa_logits

parent 30399801
...@@ -691,6 +691,8 @@ def sparse_attn_indexer( ...@@ -691,6 +691,8 @@ def sparse_attn_indexer(
max_model_len=max_model_len, max_model_len=max_model_len,
) )
else: else:
padded_q_fp8_decode_tokens = padded_q_fp8_decode_tokens.half
kv_cache = kv_cache.half
logits = gemmopt.paged_mqa_logits( logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens, padded_q_fp8_decode_tokens,
kv_cache, kv_cache,
...@@ -698,7 +700,7 @@ def sparse_attn_indexer( ...@@ -698,7 +700,7 @@ def sparse_attn_indexer(
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_model_len=max_model_len, max_context_len=max_model_len,
) )
# padded query len # padded query len
current_device = padded_q_fp8_decode_tokens.device current_device = padded_q_fp8_decode_tokens.device
......
...@@ -312,7 +312,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -312,7 +312,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer = gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms) seq_lens, self.kv_cache_spec.block_size, self.num_sms)
else: else:
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment