Unverified commit 0d766741, authored by Matthew Bonanni and committed by GitHub
Browse files

[0/N][Attention] Fix miscellaneous pre-commit issues (#31924)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 5dcd7ef1
...@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
head_size, dtype, kv_cache_dtype, block_size head_size, dtype, kv_cache_dtype, block_size
) )
attn_backend = create_static_sink_attention_backend( attn_backend = create_static_sink_attention_backend(
underlying_attn_backend, underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len, sink_len=sink_len,
) )
Attention.__init__( Attention.__init__(
......
...@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]: ...@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
is_availble, maybe_reason = _is_flashmla_available() is_availble, maybe_reason = _is_flashmla_available()
if not is_availble: if not is_availble:
return False, maybe_reason return False, maybe_reason
if current_platform.get_device_capability()[0] != 9: if not current_platform.is_device_capability_family(90):
return False, "FlashMLA Dense is only supported on Hopper devices." return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None return True, None
...@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]: ...@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
is_availble, maybe_reason = _is_flashmla_available() is_availble, maybe_reason = _is_flashmla_available()
if not is_availble: if not is_availble:
return False, maybe_reason return False, maybe_reason
if current_platform.get_device_capability()[0] not in (9, 10): if not (
current_platform.is_device_capability_family(90)
or current_platform.is_device_capability_family(100)
):
return ( return (
False, False,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.", "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
......
...@@ -7,9 +7,13 @@ import torch ...@@ -7,9 +7,13 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops
ops = _custom_ops
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops
ops = ipex_ops
class PagedAttention: class PagedAttention:
......
...@@ -754,8 +754,8 @@ def context_attention_fwd( ...@@ -754,8 +754,8 @@ def context_attention_fwd(
if current_platform.is_rocm(): if current_platform.is_rocm():
extra_kargs = {"kpack": 1, "waves_per_eu": 2} extra_kargs = {"kpack": 1, "waves_per_eu": 2}
grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid]( _fwd_kernel[grid_fn](
q, q,
k, k,
v, v,
......
...@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch( ...@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
Returns: Returns:
Logits tensor of shape [M, N], dtype `torch.float32`. Logits tensor of shape [M, N], dtype `torch.float32`.
""" """
kv, scale = kv k_fp8, scale = kv
seq_len_kv = kv.shape[0] seq_len_kv = k_fp8.shape[0]
k = kv.to(torch.bfloat16) k = k_fp8.to(torch.bfloat16)
q = q.to(torch.bfloat16) q = q.to(torch.bfloat16)
mask_lo = ( mask_lo = (
......
...@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2) split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H: VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num) mask_h = mask_h & (cur_head < q_head_num)
......
...@@ -202,9 +202,9 @@ def _fwd_kernel( ...@@ -202,9 +202,9 @@ def _fwd_kernel(
def get_block_size(dtype: torch.dtype) -> int: def get_block_size(dtype: torch.dtype) -> int:
if dtype == torch.float32: if dtype == torch.float32:
return 32 return 32
elif ( elif current_platform.is_cuda_alike() and current_platform.has_device_capability(
current_platform.is_cuda_alike() 80
) and current_platform.get_device_capability().major > 8: ):
return 128 return 128
else: else:
return 64 return 64
......
...@@ -7,16 +7,23 @@ from vllm.platforms import current_platform ...@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm import _custom_ops as ops from vllm import _custom_ops
ops = _custom_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
get_scheduler_metadata,
)
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops
ops = ipex_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata get_scheduler_metadata = ops.get_scheduler_metadata
elif current_platform.is_rocm(): elif current_platform.is_rocm():
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
...@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
return ( return (
get_flash_attn_version() == 3 get_flash_attn_version() == 3
and current_platform.get_device_capability().major == 9 and current_platform.is_device_capability_family(90)
) )
...@@ -105,10 +112,9 @@ def flash_attn_supports_mla(): ...@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
is_fa_version_supported, is_fa_version_supported,
) )
return ( return is_fa_version_supported(
is_fa_version_supported(3) 3
and current_platform.get_device_capability()[0] == 9 ) and current_platform.is_device_capability_family(90)
)
except (ImportError, AssertionError): except (ImportError, AssertionError):
pass pass
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment