Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71cd8926
Unverified
Commit
71cd8926
authored
Feb 15, 2026
by
Isotr0py
Committed by
GitHub
Feb 15, 2026
Browse files
[MM Encoder] Add Triton ViT attention backend (#32183)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
19fab441
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
178 additions
and
51 deletions
+178
-51
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+16
-2
vllm/model_executor/layers/attention/mm_encoder_attention.py
vllm/model_executor/layers/attention/mm_encoder_attention.py
+38
-0
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+5
-4
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+5
-4
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+5
-4
vllm/model_executor/models/paddleocr_vl.py
vllm/model_executor/models/paddleocr_vl.py
+2
-8
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+1
-9
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-0
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+2
-0
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+4
-11
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+2
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+19
-7
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-0
vllm/v1/attention/ops/vit_attn_wrappers.py
vllm/v1/attention/ops/vit_attn_wrappers.py
+77
-0
No files found.
tests/kernels/attention/test_mha_attn.py
View file @
71cd8926
...
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
...
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_default_torch_dtype
,
set_random_seed
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
...
@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
...
@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn
=
MMEncoderAttention
(
16
,
72
,
scale
=
1
)
attn
=
MMEncoderAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32) and float32 default dtype
# - should select the Triton attention backend
with
(
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
set_default_torch_dtype
(
torch
.
float32
),
):
attn
=
MMEncoderAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
TRITON_ATTN
def
ref_attention
(
def
ref_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -153,7 +162,12 @@ def test_mha_attn_forward(
...
@@ -153,7 +162,12 @@ def test_mha_attn_forward(
v
,
v
,
scale
=
scale
,
scale
=
scale
,
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
torch
.
testing
.
assert_close
(
output
,
ref_output
)
tol_kwargs
=
(
dict
(
rtol
=
1e-3
,
atol
=
1e-3
)
if
attn
.
attn_backend
==
AttentionBackendEnum
.
TRITON_ATTN
else
{}
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
**
tol_kwargs
)
@
pytest
.
mark
.
parametrize
(
"var_seq_len"
,
VAR_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"var_seq_len"
,
VAR_SEQ_LENS
)
...
...
vllm/model_executor/layers/attention/mm_encoder_attention.py
View file @
71cd8926
...
@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
...
@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from
vllm.v1.attention.ops.vit_attn_wrappers
import
(
from
vllm.v1.attention.ops.vit_attn_wrappers
import
(
vit_flash_attn_wrapper
,
vit_flash_attn_wrapper
,
vit_torch_sdpa_wrapper
,
vit_torch_sdpa_wrapper
,
vit_triton_attn_wrapper
,
)
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
...
@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
return
output
return
output
def _forward_triton(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention through the Triton ViT attention wrapper.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    ``cu_seqlens`` and ``max_seqlen`` must be supplied together or both
    left as None; both are forwarded unchanged to
    ``vit_triton_attn_wrapper``. When ``query`` is not 4-D on entry, the
    result is reshaped to (batch_size, seq_len, -1) before returning.
    """
    # "Both set or both None" holds exactly when the two None-checks agree.
    assert (cu_seqlens is None) == (max_seqlen is None), (
        "cu_seqlens and max_seqlen should be both set or both None."
    )

    batch, num_q_tokens = query.size()[:2]
    num_kv_tokens = key.size(1)
    # Record the caller's layout now; `query` is rebound just below.
    needs_flatten = query.dim() != 4

    query, key, value = self.view_qkv_to_4d(
        query, key, value, batch, num_q_tokens, num_kv_tokens
    )

    attn_out = vit_triton_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=batch,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )

    if not needs_flatten:
        return attn_out
    return attn_out.reshape(batch, num_q_tokens, -1)
def
forward_native
(
def
forward_native
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
...
@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
is_flash_attn_backend
:
if
self
.
is_flash_attn_backend
:
return
self
.
_forward_fa
(
query
,
key
,
value
,
cu_seqlens
,
max_seqlen
)
return
self
.
_forward_fa
(
query
,
key
,
value
,
cu_seqlens
,
max_seqlen
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TRITON_ATTN
:
return
self
.
_forward_triton
(
query
,
key
,
value
,
cu_seqlens
,
max_seqlen
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
return
self
.
_forward_sdpa
(
query
,
key
,
value
,
cu_seqlens
)
return
self
.
_forward_sdpa
(
query
,
key
,
value
,
cu_seqlens
)
else
:
else
:
...
...
vllm/model_executor/models/dots_ocr.py
View file @
71cd8926
...
@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
...
@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
int
|
None
:
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
int
|
None
:
max_seqlen
=
None
max_seqlen
=
None
if
(
if
self
.
attn_backend
in
{
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
AttentionBackendEnum
.
FLASH_ATTN
,
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
AttentionBackendEnum
.
ROCM_AITER_FA
,
):
AttentionBackendEnum
.
TRITON_ATTN
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
71cd8926
...
@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
...
@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
max_seqlen
=
None
max_seqlen
=
None
if
(
if
self
.
attn_backend
in
{
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
AttentionBackendEnum
.
FLASH_ATTN
,
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
AttentionBackendEnum
.
ROCM_AITER_FA
,
):
AttentionBackendEnum
.
TRITON_ATTN
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/glm4_1v.py
View file @
71cd8926
...
@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
)
->
torch
.
Tensor
|
None
:
max_seqlen
=
None
max_seqlen
=
None
if
(
if
self
.
attn_backend
in
{
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
AttentionBackendEnum
.
FLASH_ATTN
,
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
AttentionBackendEnum
.
ROCM_AITER_FA
,
):
AttentionBackendEnum
.
TRITON_ATTN
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/paddleocr_vl.py
View file @
71cd8926
...
@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
...
@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
head_size
=
head_dim
,
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
dtype
=
torch
.
get_default_dtype
(),
)
)
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"PaddleOCR-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
layers
=
nn
.
ModuleList
(
self
.
layers
=
nn
.
ModuleList
(
[
[
SiglipEncoderLayer
(
SiglipEncoderLayer
(
...
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
...
@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
if
self
.
attn_backend
in
{
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
}:
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
71cd8926
...
@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
dtype
=
torch
.
get_default_dtype
(),
dtype
=
torch
.
get_default_dtype
(),
)
)
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
)
with
set_model_tag
(
"Qwen2_5_VisionBlock"
,
is_encoder
=
True
):
with
set_model_tag
(
"Qwen2_5_VisionBlock"
,
is_encoder
=
True
):
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
[
...
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
if
self
.
attn_backend
in
{
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
}:
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
71cd8926
...
@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
if
self
.
attn_backend
in
{
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
}:
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
71cd8926
...
@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
...
@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
if
self
.
attn_backend
in
{
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
}:
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if
self
.
attn_backend
in
{
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
}:
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
71cd8926
...
@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
dtype
=
torch
.
get_default_dtype
(),
dtype
=
torch
.
get_default_dtype
(),
)
)
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"Qwen3-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
[
Qwen3_VisionBlock
(
Qwen3_VisionBlock
(
...
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
max_seqlen
=
torch
.
zeros
([],
device
=
cu_seqlens
.
device
)
max_seqlen
=
torch
.
zeros
([],
device
=
cu_seqlens
.
device
)
if
(
if
self
.
attn_backend
in
(
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
AttentionBackendEnum
.
FLASH_ATTN
,
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
):
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
return
max_seqlen
...
...
vllm/model_executor/models/vision.py
View file @
71cd8926
...
@@ -108,7 +108,7 @@ def get_vit_attn_backend(
...
@@ -108,7 +108,7 @@ def get_vit_attn_backend(
multimodal_config
:
MultiModalConfig
|
None
=
(
multimodal_config
:
MultiModalConfig
|
None
=
(
model_config
.
multimodal_config
if
model_config
is
not
None
else
None
model_config
.
multimodal_config
if
model_config
is
not
None
else
None
)
)
except
AssertionError
:
except
(
AssertionError
,
AttributeError
)
:
multimodal_config
=
None
multimodal_config
=
None
attn_backend_override
=
(
attn_backend_override
=
(
...
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
...
@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
multimodal_config
:
MultiModalConfig
|
None
=
(
multimodal_config
:
MultiModalConfig
|
None
=
(
model_config
.
multimodal_config
if
model_config
is
not
None
else
None
model_config
.
multimodal_config
if
model_config
is
not
None
else
None
)
)
except
AssertionError
:
except
(
AssertionError
,
AttributeError
)
:
multimodal_config
=
None
multimodal_config
=
None
mm_encoder_tp_mode
=
(
mm_encoder_tp_mode
=
(
...
...
vllm/platforms/cuda.py
View file @
71cd8926
...
@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
...
@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
@
classmethod
@
classmethod
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
return
[
return
[
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
]
]
@
classmethod
@
classmethod
...
@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
...
@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
logger
.
info_once
(
f
"Using backend
{
backend
}
for vit attention"
)
logger
.
info_once
(
f
"Using backend
{
backend
}
for vit attention"
)
return
backend
return
backend
# Try FlashAttention first
cc
=
cls
.
get_device_capability
()
if
(
cc
:
=
cls
.
get_device_capability
())
and
cc
.
major
>=
8
:
for
vit_attn_backend
in
cls
.
get_supported_vit_attn_backends
():
if
vit_attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
continue
try
:
try
:
backend_class
=
AttentionBackendEnum
.
FLASH_ATTN
.
get_class
()
backend_class
=
vit_attn_backend
.
get_class
()
i
f
backend_class
.
supports_head_size
(
i
s_backend_supported
=
backend_class
.
supports_head_size
(
head_size
head_size
)
and
backend_class
.
supports_dtype
(
dtype
):
)
and
backend_class
.
supports_dtype
(
dtype
)
return
AttentionBackendEnum
.
FLASH_ATTN
if
cc
is
not
None
:
is_backend_supported
=
(
is_backend_supported
and
backend_class
.
supports_compute_capability
(
cc
)
)
if
is_backend_supported
:
logger
.
info_once
(
f
"Using backend
{
vit_attn_backend
}
for vit attention"
)
return
vit_attn_backend
except
ImportError
:
except
ImportError
:
pass
pass
...
...
vllm/platforms/rocm.py
View file @
71cd8926
...
@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
...
@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
return
[
return
[
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
TORCH_SDPA
,
]
]
...
...
vllm/v1/attention/ops/vit_attn_wrappers.py
View file @
71cd8926
...
@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
...
@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
)
)
def triton_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    batch_size: int,
    scale: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """Non-causal attention via the Triton ``context_attention_fwd`` kernel.

    The first two dimensions of ``q``, ``k`` and ``v`` are merged before the
    kernel call, and the kernel output is split back into
    ``(batch_size, s, h, d)`` afterwards.

    Args:
        q, k, v: Input tensors; ``q.size(1)`` is taken as the per-sequence
            length when ``cu_seqlens`` / ``max_seqlen`` are not given.
        batch_size: Number of sequences; used to build the default
            ``cu_seqlens`` and to un-merge the output.
        scale: Passed to the kernel as ``softmax_scale``.
        cu_seqlens: Cumulative sequence boundaries. When None, evenly spaced
            boundaries of width ``q.size(1)`` are generated.
        max_seqlen: Scalar tensor converted with ``.item()``; when None,
            ``q.size(1)`` is used.

    Returns:
        The attention output rearranged to ``b s h d``.
    """
    from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd

    tokens_per_seq = q.size(1)

    if cu_seqlens is None:
        # Equal-length sequences: boundaries at 0, s, 2s, ..., batch_size * s.
        cu_seqlens = torch.arange(
            0,
            (batch_size + 1) * tokens_per_seq,
            step=tokens_per_seq,
            dtype=torch.int32,
            device=q.device,
        )

    if max_seqlen is None:
        max_input_len = tokens_per_seq
    else:
        max_input_len = max_seqlen.item()

    # Merge the leading (batch, seq) dimensions of each input.
    merge_pattern = "b s ... -> (b s) ..."
    q = einops.rearrange(q, merge_pattern)
    k = einops.rearrange(k, merge_pattern)
    v = einops.rearrange(v, merge_pattern)

    packed_out = torch.empty_like(q)
    context_attention_fwd(
        q,
        k,
        v,
        packed_out,
        b_start_loc=cu_seqlens[:-1],
        b_seq_len=cu_seqlens[1:] - cu_seqlens[:-1],
        max_input_len=max_input_len,
        is_causal=False,
        sliding_window_q=None,
        sliding_window_k=None,
        softmax_scale=scale,
    )

    return einops.rearrange(packed_out, "(b s) h d -> b s h d", b=batch_size)
def
triton_attn_wrapper_fake
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
batch_size
:
int
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
# Register triton_attn_wrapper as a custom op under the name
# "triton_attn_wrapper", with triton_attn_wrapper_fake as its fake
# implementation.
direct_register_custom_op(
    op_name="triton_attn_wrapper",
    op_func=triton_attn_wrapper,
    fake_impl=triton_attn_wrapper_fake,
)
def vit_triton_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    batch_size: int,
    scale: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """Dispatch to the ``torch.ops.vllm.triton_attn_wrapper`` custom op.

    All arguments are forwarded positionally, in declaration order, and the
    op's result is returned as-is.
    """
    triton_attn_op = torch.ops.vllm.triton_attn_wrapper
    return triton_attn_op(q, k, v, batch_size, scale, cu_seqlens, max_seqlen)
def
apply_sdpa
(
def
apply_sdpa
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment