Unverified Commit 1b117cb0 authored by wufann's avatar wufann Committed by GitHub
Browse files

[ROCm] Fix aiter persistent mode mla with q/o nhead<16 for kimi-k2.5 tp8 (#38615)


Signed-off-by: default avatarwufann <36477220+wufann@users.noreply.github.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent abebd932
......@@ -129,9 +129,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
from aiter import dtypes, get_mla_metadata_info_v1
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
# For num_attention_heads < 16 (e.g. kimi-k2.5 head=8 with TP8),
# make sure get_mla_metadata_info_v1 / get_mla_metadata_v1 are consistent
# with the actual tensor shape passed to mla_decode_fwd.
self._num_attention_heads = max(16, self.num_heads)
q_dtype = self.decode_attn_out_dtype
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment