Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d766741
Unverified
Commit
0d766741
authored
Jan 07, 2026
by
Matthew Bonanni
Committed by
GitHub
Jan 08, 2026
Browse files
[0/N][Attention] Fix miscellaneous pre-commit issues (#31924)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
5dcd7ef1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
35 additions
and
25 deletions
+35
-25
vllm/attention/layers/static_sink_attention.py
vllm/attention/layers/static_sink_attention.py
+1
-1
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+5
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+6
-2
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+2
-2
vllm/attention/ops/rocm_aiter_mla_sparse.py
vllm/attention/ops/rocm_aiter_mla_sparse.py
+3
-3
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+1
-4
vllm/attention/ops/triton_prefill_attention.py
vllm/attention/ops/triton_prefill_attention.py
+3
-3
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+14
-8
No files found.
vllm/attention/layers/static_sink_attention.py
View file @
0d766741
...
@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
...
@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
head_size
,
dtype
,
kv_cache_dtype
,
block_size
head_size
,
dtype
,
kv_cache_dtype
,
block_size
)
)
attn_backend
=
create_static_sink_attention_backend
(
attn_backend
=
create_static_sink_attention_backend
(
underlying_attn_backend
,
underlying_attn_backend
,
# type: ignore[arg-type]
sink_len
=
sink_len
,
sink_len
=
sink_len
,
)
)
Attention
.
__init__
(
Attention
.
__init__
(
...
...
vllm/attention/ops/flashmla.py
View file @
0d766741
...
@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
...
@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
is_availble
,
maybe_reason
=
_is_flashmla_available
()
is_availble
,
maybe_reason
=
_is_flashmla_available
()
if
not
is_availble
:
if
not
is_availble
:
return
False
,
maybe_reason
return
False
,
maybe_reason
if
current_platform
.
get
_device_capability
()[
0
]
!=
9
:
if
not
current_platform
.
is
_device_capability
_family
(
90
)
:
return
False
,
"FlashMLA Dense is only supported on Hopper devices."
return
False
,
"FlashMLA Dense is only supported on Hopper devices."
return
True
,
None
return
True
,
None
...
@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
...
@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
is_availble
,
maybe_reason
=
_is_flashmla_available
()
is_availble
,
maybe_reason
=
_is_flashmla_available
()
if
not
is_availble
:
if
not
is_availble
:
return
False
,
maybe_reason
return
False
,
maybe_reason
if
current_platform
.
get_device_capability
()[
0
]
not
in
(
9
,
10
):
if
not
(
current_platform
.
is_device_capability_family
(
90
)
or
current_platform
.
is_device_capability_family
(
100
)
):
return
(
return
(
False
,
False
,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices."
,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices."
,
...
...
vllm/attention/ops/paged_attn.py
View file @
0d766741
...
@@ -7,9 +7,13 @@ import torch
...
@@ -7,9 +7,13 @@ import torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
ops
=
_custom_ops
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
ops
=
ipex_ops
class
PagedAttention
:
class
PagedAttention
:
...
...
vllm/attention/ops/prefix_prefill.py
View file @
0d766741
...
@@ -754,8 +754,8 @@ def context_attention_fwd(
...
@@ -754,8 +754,8 @@ def context_attention_fwd(
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
extra_kargs
=
{
"kpack"
:
1
,
"waves_per_eu"
:
2
}
extra_kargs
=
{
"kpack"
:
1
,
"waves_per_eu"
:
2
}
grid
=
lambda
META
:
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
META
[
"BLOCK_M"
]))
grid
_fn
=
lambda
META
:
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
META
[
"BLOCK_M"
]))
_fwd_kernel
[
grid
](
_fwd_kernel
[
grid
_fn
](
q
,
q
,
k
,
k
,
v
,
v
,
...
...
vllm/attention/ops/rocm_aiter_mla_sparse.py
View file @
0d766741
...
@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
...
@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
Returns:
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
"""
k
v
,
scale
=
kv
k
_fp8
,
scale
=
kv
seq_len_kv
=
k
v
.
shape
[
0
]
seq_len_kv
=
k
_fp8
.
shape
[
0
]
k
=
k
v
.
to
(
torch
.
bfloat16
)
k
=
k
_fp8
.
to
(
torch
.
bfloat16
)
q
=
q
.
to
(
torch
.
bfloat16
)
q
=
q
.
to
(
torch
.
bfloat16
)
mask_lo
=
(
mask_lo
=
(
...
...
vllm/attention/ops/triton_decode_attention.py
View file @
0d766741
...
@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
...
@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
split_kv_id
=
tl
.
program_id
(
2
)
split_kv_id
=
tl
.
program_id
(
2
)
if
kv_group_num
>
BLOCK_H
:
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
if
kv_group_num
>
BLOCK_H
else
kv_group_num
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
else
:
VALID_BLOCK_H
:
tl
.
constexpr
=
kv_group_num
cur_head
=
cur_head_id
*
VALID_BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
cur_head
=
cur_head_id
*
VALID_BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
mask_h
=
cur_head
<
(
cur_head_id
+
1
)
*
VALID_BLOCK_H
mask_h
=
cur_head
<
(
cur_head_id
+
1
)
*
VALID_BLOCK_H
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
...
...
vllm/attention/ops/triton_prefill_attention.py
View file @
0d766741
...
@@ -202,9 +202,9 @@ def _fwd_kernel(
...
@@ -202,9 +202,9 @@ def _fwd_kernel(
def
get_block_size
(
dtype
:
torch
.
dtype
)
->
int
:
def
get_block_size
(
dtype
:
torch
.
dtype
)
->
int
:
if
dtype
==
torch
.
float32
:
if
dtype
==
torch
.
float32
:
return
32
return
32
elif
(
elif
current_platform
.
is_cuda_alike
()
and
current_platform
.
has_device_capability
(
current_platform
.
is_cuda_alike
()
80
)
and
current_platform
.
get_device_capability
().
major
>
8
:
):
return
128
return
128
else
:
else
:
return
64
return
64
...
...
vllm/attention/utils/fa_utils.py
View file @
0d766741
...
@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
...
@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
ops
=
_custom_ops
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
,
get_scheduler_metadata
from
vllm.vllm_flash_attn
import
(
# type: ignore[attr-defined]
flash_attn_varlen_func
,
get_scheduler_metadata
,
)
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
ops
=
ipex_ops
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
flash_attn_varlen_func
=
ops
.
flash_attn_varlen_func
flash_attn_varlen_func
=
ops
.
flash_attn_varlen_func
get_scheduler_metadata
=
ops
.
get_scheduler_metadata
get_scheduler_metadata
=
ops
.
get_scheduler_metadata
elif
current_platform
.
is_rocm
():
elif
current_platform
.
is_rocm
():
try
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
...
@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
...
@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def
flash_attn_supports_fp8
()
->
bool
:
def
flash_attn_supports_fp8
()
->
bool
:
return
(
return
(
get_flash_attn_version
()
==
3
get_flash_attn_version
()
==
3
and
current_platform
.
get
_device_capability
().
major
==
9
and
current_platform
.
is
_device_capability
_family
(
90
)
)
)
...
@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
...
@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
is_fa_version_supported
,
is_fa_version_supported
,
)
)
return
(
return
is_fa_version_supported
(
is_fa_version_supported
(
3
)
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
and
current_platform
.
is_device_capability_family
(
90
)
)
except
(
ImportError
,
AssertionError
):
except
(
ImportError
,
AssertionError
):
pass
pass
return
False
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment