Unverified Commit 095d2f87 authored by qli88's avatar qli88 Committed by GitHub
Browse files

[Bug] Fix GLM-5.1 running error on ROCm platform (#40763)


Signed-off-by: default avatarQiang Li <qiang.li2@amd.com>
parent 21792520
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar, Final
import torch import torch
...@@ -389,6 +389,53 @@ def _expand_page_indices_kernel( ...@@ -389,6 +389,53 @@ def _expand_page_indices_kernel(
) )
class AiterMLAHelper:
"""
AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and
16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.
"""
_AITER_MIN_MLA_HEADS: Final = 16
@staticmethod
def check_num_heads_validity(num_heads: int):
assert AiterMLAHelper.is_valid_num_heads(num_heads), (
f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
f"but provided {num_heads} number of heads.\n"
f"Try adjusting tensor_parallel_size value."
)
@staticmethod
def is_valid_num_heads(num_heads: int) -> bool:
return (
num_heads % AiterMLAHelper._AITER_MIN_MLA_HEADS == 0
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
else AiterMLAHelper._AITER_MIN_MLA_HEADS % num_heads == 0
)
@staticmethod
def get_actual_mla_num_heads(num_heads: int) -> int:
return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)
@staticmethod
def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
return (
q
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
else q.repeat_interleave(
AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, dim=1
)
)
@staticmethod
def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
return (
o
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
else o[:, :: AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, :]
)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__( def __init__(
self, self,
...@@ -418,17 +465,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -418,17 +465,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
**mla_args, **mla_args,
) )
_valid_heads = num_heads in (4, 8) or ( AiterMLAHelper.check_num_heads_validity(num_heads)
num_heads % 16 == 0 and 16 <= num_heads <= 128
)
assert _valid_heads, (
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f"in [16, 128].\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value."
)
self._needs_head_repeat = num_heads < 16
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features): if any(unsupported_features):
raise NotImplementedError( raise NotImplementedError(
...@@ -471,15 +509,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -471,15 +509,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor) assert isinstance(q, torch.Tensor)
B = q.shape[0] B = q.shape[0]
if self._needs_head_repeat: mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
q = q.repeat_interleave(self._head_repeat_factor, dim=1) mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
kernel_num_heads = 16
else:
kernel_num_heads = self.num_heads
o = torch.empty( o = torch.empty(
B, B,
kernel_num_heads, mla_num_heads,
self.kv_lora_rank, self.kv_lora_rank,
dtype=attn_metadata.decode.attn_out_dtype, dtype=attn_metadata.decode.attn_out_dtype,
device=q.device, device=q.device,
...@@ -506,7 +540,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -506,7 +540,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
) )
rocm_aiter_ops.mla_decode_fwd( rocm_aiter_ops.mla_decode_fwd(
q, mla_padded_q,
kv_buffer, kv_buffer,
o, o,
self.scale, self.scale,
...@@ -518,7 +552,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -518,7 +552,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**mla_kwargs, **mla_kwargs,
) )
if self._needs_head_repeat: return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, o), None
o = o[:, :: self._head_repeat_factor, :]
return o, None
...@@ -28,6 +28,9 @@ from vllm.v1.attention.backend import ( ...@@ -28,6 +28,9 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
) )
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
AiterMLAHelper,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -299,6 +302,8 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]) ...@@ -299,6 +302,8 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
indexer: "Indexer | None" = None, indexer: "Indexer | None" = None,
**mla_args, **mla_args,
) -> None: ) -> None:
AiterMLAHelper.check_num_heads_validity(num_heads)
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -317,8 +322,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]) ...@@ -317,8 +322,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata: ROCMAiterMLASparseMetadata, attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = q.shape[0] num_tokens = q.shape[0]
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
output = torch.empty( output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank], [num_tokens, mla_num_heads, self.kv_lora_rank],
dtype=q.dtype, dtype=q.dtype,
device=q.device, device=q.device,
) )
...@@ -344,7 +350,7 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]) ...@@ -344,7 +350,7 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata.paged_kv_last_page_len, attn_metadata.paged_kv_last_page_len,
) )
return output[:, : self.num_heads, :] return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output)
def forward_mqa( def forward_mqa(
self, self,
...@@ -374,8 +380,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]) ...@@ -374,8 +380,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
NUM_TOPK_TOKENS=attn_metadata.topk_tokens, NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
) )
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
attn_out = self._forward_bf16_kv( attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
) )
return attn_out, None return attn_out, None
...@@ -339,6 +339,8 @@ def rocm_fp8_paged_mqa_logits( ...@@ -339,6 +339,8 @@ def rocm_fp8_paged_mqa_logits(
device="cuda", device="cuda",
dtype=torch.float32, dtype=torch.float32,
) )
# TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
# 2. Remove ChunkQ when AITER PR #2891 merged
deepgemm_fp8_paged_mqa_logits_stage1( deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8, q_fp8,
kv_cache_fp8, kv_cache_fp8,
...@@ -347,6 +349,7 @@ def rocm_fp8_paged_mqa_logits( ...@@ -347,6 +349,7 @@ def rocm_fp8_paged_mqa_logits(
context_lens, context_lens,
block_tables, block_tables,
max_model_len, max_model_len,
ChunkQ=heads,
) )
return out_qk.sum(dim=0) return out_qk.sum(dim=0)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment