Unverified commit c188749b, authored by Chuan (Richard) Li; committed by GitHub
Browse files

[ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)


Signed-off-by: Li <chuali@amd.com>
parent 225d1090
...@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
**mla_args, **mla_args,
) )
assert num_heads == 16 or num_heads == 128, ( _valid_heads = num_heads in (4, 8) or (
f"Aiter MLA only supports 16 or 128 number of heads.\n" num_heads % 16 == 0 and 16 <= num_heads <= 128
)
assert _valid_heads, (
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f"in [16, 128].\n"
f"Provided {num_heads} number of heads.\n" f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value." "Try adjusting tensor_parallel_size value."
) )
self._needs_head_repeat = num_heads < 16
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features): if any(unsupported_features):
raise NotImplementedError( raise NotImplementedError(
...@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor) assert isinstance(q, torch.Tensor)
B = q.shape[0] B = q.shape[0]
if self._needs_head_repeat:
q = q.repeat_interleave(self._head_repeat_factor, dim=1)
kernel_num_heads = 16
else:
kernel_num_heads = self.num_heads
o = torch.zeros( o = torch.zeros(
B, B,
self.num_heads, kernel_num_heads,
self.kv_lora_rank, self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype, dtype=attn_metadata.decode.attn_out_dtype,
device=q.device, device=q.device,
...@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_scale=layer._k_scale, kv_scale=layer._k_scale,
) )
if self._needs_head_repeat:
o = o[:, :: self._head_repeat_factor, :]
return o, None return o, None
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment