Unverified commit 095d2f87, authored by qli88 and committed by GitHub
Browse files

[Bug] Fix GLM-5.1 running error on ROCm platform (#40763)


Signed-off-by: Qiang Li <qiang.li2@amd.com>
parent 21792520
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Final
import torch
......@@ -389,6 +389,53 @@ def _expand_page_indices_kernel(
)
class AiterMLAHelper:
    """Head-count helpers shared by the AITER MLA attention backends.

    The AITER MLA implementation requires ``num_heads >= 16``. When
    ``num_heads < 16`` and ``16 % num_heads == 0`` the query can be padded to
    16 heads by repeating each head, and the kernel output is then strided
    back down to the original head count. Any other head count below 16, and
    any head count above 16 that is not a multiple of 16, is rejected.
    """

    # Smallest number of query heads the AITER MLA kernel accepts.
    _AITER_MIN_MLA_HEADS: Final = 16

    @staticmethod
    def check_num_heads_validity(num_heads: int) -> None:
        """Fail fast when ``num_heads`` cannot be served by AITER MLA.

        Raises:
            AssertionError: if ``is_valid_num_heads(num_heads)`` is false.
                The exception is raised explicitly (rather than through an
                ``assert`` statement) so the check is not stripped when
                Python runs with ``-O``.
        """
        if not AiterMLAHelper.is_valid_num_heads(num_heads):
            raise AssertionError(
                f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
                f"but provided {num_heads} number of heads.\n"
                f"Try adjusting tensor_parallel_size value."
            )

    @staticmethod
    def is_valid_num_heads(num_heads: int) -> bool:
        """Return whether ``num_heads`` is a positive multiple or divisor of 16.

        Non-positive values are reported as invalid instead of raising
        ``ZeroDivisionError`` (``num_heads == 0``) or being accepted because
        Python's modulo of a negative divisor happens to be zero.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        if num_heads <= 0:
            return False
        if num_heads >= min_heads:
            return num_heads % min_heads == 0
        return min_heads % num_heads == 0

    @staticmethod
    def get_actual_mla_num_heads(num_heads: int) -> int:
        """Return the head count the kernel runs with: at least 16."""
        return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)

    @staticmethod
    def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
        """Return ``q`` expanded to 16 heads when ``num_heads < 16``.

        Each entry along dim 1 is repeated ``16 // num_heads`` times in place
        (``repeat_interleave``), so copies of one head stay adjacent. ``q`` is
        returned unchanged when ``num_heads >= 16``.

        NOTE(review): dim 1 is assumed to be the head axis with size
        ``num_heads``, and ``num_heads`` is assumed to have passed
        ``check_num_heads_validity`` -- confirm against callers.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        if num_heads >= min_heads:
            return q
        return q.repeat_interleave(min_heads // num_heads, dim=1)

    @staticmethod
    def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
        """Undo the head padding on a kernel output tensor.

        Takes every ``16 // num_heads``-th entry along dim 1, i.e. the first
        copy of each repeated head produced by ``get_mla_padded_q``. ``o`` is
        returned unchanged when ``num_heads >= 16``.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        if num_heads >= min_heads:
            return o
        return o[:, :: min_heads // num_heads, :]
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
......@@ -418,17 +465,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name,
**mla_args,
)
_valid_heads = num_heads in (4, 8) or (
num_heads % 16 == 0 and 16 <= num_heads <= 128
)
assert _valid_heads, (
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f"in [16, 128].\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value."
)
self._needs_head_repeat = num_heads < 16
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
AiterMLAHelper.check_num_heads_validity(num_heads)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
......@@ -471,15 +509,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
if self._needs_head_repeat:
q = q.repeat_interleave(self._head_repeat_factor, dim=1)
kernel_num_heads = 16
else:
kernel_num_heads = self.num_heads
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
o = torch.empty(
B,
kernel_num_heads,
mla_num_heads,
self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype,
device=q.device,
......@@ -506,7 +540,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
)
rocm_aiter_ops.mla_decode_fwd(
q,
mla_padded_q,
kv_buffer,
o,
self.scale,
......@@ -518,7 +552,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**mla_kwargs,
)
if self._needs_head_repeat:
o = o[:, :: self._head_repeat_factor, :]
return o, None
return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, o), None
......@@ -28,6 +28,9 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
AiterMLAHelper,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
......@@ -299,6 +302,8 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
AiterMLAHelper.check_num_heads_validity(num_heads)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......@@ -317,8 +322,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
num_tokens = q.shape[0]
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
[num_tokens, mla_num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
......@@ -344,7 +350,7 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata.paged_kv_last_page_len,
)
return output[:, : self.num_heads, :]
return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output)
def forward_mqa(
self,
......@@ -374,8 +380,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
)
return attn_out, None
......@@ -339,6 +339,8 @@ def rocm_fp8_paged_mqa_logits(
device="cuda",
dtype=torch.float32,
)
# TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
# 2. Remove ChunkQ once AITER PR #2891 is merged.
deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8,
kv_cache_fp8,
......@@ -347,6 +349,7 @@ def rocm_fp8_paged_mqa_logits(
context_lens,
block_tables,
max_model_len,
ChunkQ=heads,
)
return out_qk.sum(dim=0)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment