Unverified Commit c188749b authored by Chuan (Richard) Li's avatar Chuan (Richard) Li Committed by GitHub
Browse files

[ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)


Signed-off-by: default avatarLi <chuali@amd.com>
parent 225d1090
......@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name,
**mla_args,
)
assert num_heads == 16 or num_heads == 128, (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
_valid_heads = num_heads in (4, 8) or (
num_heads % 16 == 0 and 16 <= num_heads <= 128
)
assert _valid_heads, (
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f"in [16, 128].\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value."
)
self._needs_head_repeat = num_heads < 16
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
......@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
if self._needs_head_repeat:
q = q.repeat_interleave(self._head_repeat_factor, dim=1)
kernel_num_heads = 16
else:
kernel_num_heads = self.num_heads
o = torch.zeros(
B,
self.num_heads,
kernel_num_heads,
self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype,
device=q.device,
......@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_scale=layer._k_scale,
)
if self._needs_head_repeat:
o = o[:, :: self._head_repeat_factor, :]
return o, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment