Unverified Commit 8da6ae49 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][Bugfix] Fix `fa_version` argument error in...


[ROCm][Bugfix] Fix `fa_version` argument error in `flash_attn_maxseqlen_wrapper` for ROCm without aiter (#30909)
Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 30bb19a7
...@@ -28,7 +28,7 @@ def flash_attn_maxseqlen_wrapper( ...@@ -28,7 +28,7 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int, fa_version: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
kwargs = {} kwargs = {}
if is_rocm_aiter: if is_rocm_aiter:
...@@ -36,7 +36,8 @@ def flash_attn_maxseqlen_wrapper( ...@@ -36,7 +36,8 @@ def flash_attn_maxseqlen_wrapper(
else: else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func from vllm.attention.utils.fa_utils import flash_attn_varlen_func
kwargs["fa_version"] = fa_version if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,
...@@ -62,7 +63,7 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -62,7 +63,7 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int, fa_version: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(q) return torch.empty_like(q)
...@@ -82,7 +83,7 @@ def vit_flash_attn_wrapper( ...@@ -82,7 +83,7 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int, fa_version: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment