Unverified Commit 97a3d6d9 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Bugfix] Massage MLA's usage of flash attn for ROCm (#13310)

parent 579d7a63
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple from typing import Any, Dict, Generic, List, Optional, Tuple
...@@ -183,6 +184,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -183,6 +184,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.o_proj = o_proj self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
# Handle the differences between flash_attn_varlen_func from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm; the
# latter takes an additional `fa_version` parameter to select FA2 vs FA3.
self.flash_attn_varlen_func = flash_attn_varlen_func
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O): if is_fp8(self.W_UV_O):
...@@ -487,7 +497,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -487,7 +497,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0) value=0)
attn_output = flash_attn_varlen_func( attn_output = self.flash_attn_varlen_func(
q=q, q=q,
k=k, k=k,
v=v_padded, v=v_padded,
...@@ -497,7 +507,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -497,7 +507,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k=max_prefill_seq_len, max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
fa_version=self.vllm_flash_attn_version,
) )
attn_output = attn_output\ attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment