Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469b3ffa
Unverified
Commit
469b3ffa
authored
Aug 05, 2025
by
Giancarlo Delfin
Committed by
GitHub
Aug 05, 2025
Browse files
[V1] port xformers backend to v1 (#21342)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@meta.com
>
parent
ae87ddd0
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
438 additions
and
1 deletion
+438
-1
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+0
-1
vllm/v1/attention/backends/xformers.py
vllm/v1/attention/backends/xformers.py
+430
-0
No files found.
tests/v1/attention/utils.py
View file @
469b3ffa
...
@@ -128,6 +128,8 @@ def get_attention_backend(backend_name: _Backend):
...
@@ -128,6 +128,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
,
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
,
_Backend
.
TREE_ATTN
:
_Backend
.
TREE_ATTN
:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
,
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
,
_Backend
.
XFORMERS_VLLM_V1
:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
,
}
}
if
backend_name
not
in
backend_map
:
if
backend_name
not
in
backend_map
:
...
...
vllm/engine/arg_utils.py
View file @
469b3ffa
...
@@ -1469,6 +1469,7 @@ class EngineArgs:
...
@@ -1469,6 +1469,7 @@ class EngineArgs:
"TORCH_SDPA_VLLM_V1"
,
"TORCH_SDPA_VLLM_V1"
,
"FLEX_ATTENTION"
,
"FLEX_ATTENTION"
,
"TREE_ATTN"
,
"TREE_ATTN"
,
"XFORMERS_VLLM_V1"
,
]
]
if
(
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
if
(
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
and
envs
.
VLLM_ATTENTION_BACKEND
not
in
V1_BACKENDS
):
and
envs
.
VLLM_ATTENTION_BACKEND
not
in
V1_BACKENDS
):
...
...
vllm/platforms/cuda.py
View file @
469b3ffa
...
@@ -271,6 +271,7 @@ class CudaPlatformBase(Platform):
...
@@ -271,6 +271,7 @@ class CudaPlatformBase(Platform):
TRITON_ATTN_VLLM_V1
=
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# noqa: E501
TRITON_ATTN_VLLM_V1
=
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# noqa: E501
FLASH_ATTN_V1
=
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# noqa: E501
FLASH_ATTN_V1
=
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# noqa: E501
TREE_ATTN_V1
=
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
# noqa: E501
TREE_ATTN_V1
=
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
# noqa: E501
XFORMERS_V1
=
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
# noqa: E501
if
selected_backend
==
_Backend
.
FLASHINFER
:
if
selected_backend
==
_Backend
.
FLASHINFER
:
logger
.
info_once
(
"Using FlashInfer backend on V1 engine."
)
logger
.
info_once
(
"Using FlashInfer backend on V1 engine."
)
...
@@ -291,6 +292,9 @@ class CudaPlatformBase(Platform):
...
@@ -291,6 +292,9 @@ class CudaPlatformBase(Platform):
elif
selected_backend
==
_Backend
.
TREE_ATTN
:
elif
selected_backend
==
_Backend
.
TREE_ATTN
:
logger
.
info_once
(
"Using Tree Attention backend on V1 engine."
)
logger
.
info_once
(
"Using Tree Attention backend on V1 engine."
)
return
TREE_ATTN_V1
return
TREE_ATTN_V1
elif
selected_backend
==
_Backend
.
XFORMERS_VLLM_V1
:
logger
.
info_once
(
"Using XFormers backend on V1 engine."
)
return
XFORMERS_V1
from
vllm.attention.selector
import
is_attn_backend_supported
from
vllm.attention.selector
import
is_attn_backend_supported
...
...
vllm/platforms/interface.py
View file @
469b3ffa
...
@@ -63,6 +63,7 @@ class _Backend(enum.Enum):
...
@@ -63,6 +63,7 @@ class _Backend(enum.Enum):
NO_ATTENTION
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
XFORMERS_VLLM_V1
=
enum
.
auto
()
class
PlatformEnum
(
enum
.
Enum
):
class
PlatformEnum
(
enum
.
Enum
):
...
...
vllm/v1/attention/backends/tree_attn.py
View file @
469b3ffa
...
@@ -316,7 +316,6 @@ class TreeAttentionImpl(AttentionImpl):
...
@@ -316,7 +316,6 @@ class TreeAttentionImpl(AttentionImpl):
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
...
vllm/v1/attention/backends/xformers.py
0 → 100644
View file @
469b3ffa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
try
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
)
XFORMERS_AVAILABLE
=
True
except
ImportError
:
XFORMERS_AVAILABLE
=
False
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
class
XFormersAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
40
,
48
,
56
,
64
,
72
,
80
,
88
,
96
,
104
,
112
,
120
,
128
,
136
,
144
,
152
,
160
,
168
,
176
,
184
,
192
,
200
,
208
,
216
,
224
,
232
,
240
,
248
,
256
,
]
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
supported_head_sizes
=
cls
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
attn_type
=
cls
.
__name__
.
removesuffix
(
"Backend"
)
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Supported head sizes are:
{
supported_head_sizes
}
. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@
staticmethod
def
get_name
()
->
str
:
return
"XFORMERS_VLLM_V1"
@
staticmethod
def
get_impl_cls
()
->
type
[
"XFormersAttentionImpl"
]:
return
XFormersAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
type
[
"AttentionMetadata"
]:
return
XFormersAttentionMetadata
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_builder_cls
()
->
type
[
"XFormersAttentionMetadataBuilder"
]:
return
XFormersAttentionMetadataBuilder
@
staticmethod
def
use_cascade_attention
(
*
args
,
**
kwargs
)
->
bool
:
return
False
@
dataclass
class
XFormersAttentionMetadata
:
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_prefill_tokens
:
int
=
0
num_decode_tokens
:
int
=
0
num_prefills
:
int
=
0
num_decodes
:
int
=
0
# Biases for different attention types.
attn_bias
:
Optional
[
"AttentionBias"
]
=
None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata
:
Optional
[
"XFormersAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersAttentionMetadata"
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
# Recover cached prefill-phase attention
# metadata structure
return
self
.
_cached_prefill_metadata
q_start_loc
=
self
.
query_start_loc
[
self
.
num_decodes
:]
q_seqlens
=
torch
.
diff
(
q_start_loc
)
kv_seqlens
=
self
.
seq_lens
[
self
.
num_decodes
:]
# Construct & cache prefill-phase attention metadata structure
self
.
_cached_prefill_metadata
=
XFormersAttentionMetadata
(
num_actual_tokens
=
self
.
num_prefill_tokens
,
max_query_len
=
int
(
q_seqlens
.
max
().
item
()),
query_start_loc
=
q_start_loc
-
q_start_loc
[
0
],
max_seq_len
=
int
(
kv_seqlens
.
max
().
item
()),
seq_lens
=
kv_seqlens
,
block_table
=
self
.
block_table
[
self
.
num_decodes
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_decode_tokens
:],
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"XFormersAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
# Recover cached decode-phase attention
# metadata structure
return
self
.
_cached_decode_metadata
q_start_loc
=
self
.
query_start_loc
q_seqlens
=
torch
.
diff
(
q_start_loc
)
decode_kv_seqlens
=
self
.
seq_lens
[:
self
.
num_decodes
]
# Construct & cache decode-phase attention metadata structure
self
.
_cached_decode_metadata
=
XFormersAttentionMetadata
(
num_actual_tokens
=
self
.
num_decode_tokens
,
max_query_len
=
int
(
q_seqlens
[:
self
.
num_decodes
].
max
().
item
()),
query_start_loc
=
q_start_loc
[:
self
.
num_decodes
+
1
],
max_seq_len
=
int
(
decode_kv_seqlens
.
max
().
item
()),
seq_lens
=
decode_kv_seqlens
,
block_table
=
self
.
block_table
[:
self
.
num_decodes
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_decode_tokens
],
attn_bias
=
self
.
attn_bias
,
)
return
self
.
_cached_decode_metadata
class
XFormersAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
XFormersAttentionMetadata
]):
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
assert
XFORMERS_AVAILABLE
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
_num_decodes
=
0
self
.
_num_decode_tokens
=
0
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
reorder_batch_to_split_decodes_and_prefills
(
input_batch
,
scheduler_output
,
decode_threshold
=
1
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
XFormersAttentionMetadata
:
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
1
))
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
q_start_loc
=
common_attn_metadata
.
query_start_loc
q_seqlens
=
torch
.
diff
(
q_start_loc
)
max_query_len
=
common_attn_metadata
.
max_query_len
kv_seqlens
=
common_attn_metadata
.
seq_lens
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
max
())
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
bias
=
None
if
num_decodes
>
0
:
# Construct the decoder bias.
decode_q_seqlens
=
q_seqlens
[:
num_decodes
]
decode_kv_seqlens
=
kv_seqlens
[:
num_decodes
]
bias
=
(
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
.
from_seqlens
(
q_seqlen
=
decode_q_seqlens
.
tolist
(),
kv_seqlen
=
decode_kv_seqlens
.
tolist
(),
page_size
=
self
.
block_size
,
block_tables
=
block_table
[:
num_decodes
],
device
=
block_table
.
device
,
))
return
XFormersAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_prefills
=
num_prefills
,
num_decodes
=
num_decodes
,
max_query_len
=
max_query_len
,
query_start_loc
=
q_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
kv_seqlens
,
block_table
=
block_table
,
slot_mapping
=
slot_mapping
,
attn_bias
=
bias
,
)
class
XFormersAttentionImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"XFormers does not support alibi slopes yet."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
None
:
self
.
sliding_window
=
(
-
1
,
-
1
)
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
if
logits_soft_cap
is
None
:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
XFormersAttentionBackend
.
validate_head_size
(
head_size
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"XFormersAttentionImpl."
)
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
XFormersAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with XFormers.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl"
)
if
attn_metadata
is
None
:
# Profiling run.
return
output
# Cache the input KVs.
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
unified_attention
(
q
=
query
[
num_decode_tokens
:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[
num_decode_tokens
:
num_actual_tokens
],
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
prefill_meta
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
q_descale
=
None
,
# Not supported
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[:
num_decode_tokens
]
# Reshape query to [1, B_T, G, H, D].
q
=
decode_query
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
)
# Reshape the k and v caches to [1, Bkv_T, G, H, D]
cache_k
=
key_cache
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
1
,
self
.
head_size
).
expand
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
,
)
cache_v
=
value_cache
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
1
,
self
.
head_size
).
expand
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
,
)
attn_bias
=
decode_meta
.
attn_bias
output
[:
num_decode_tokens
]
=
xops
.
memory_efficient_attention_forward
(
q
,
cache_k
,
cache_v
,
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
self
.
scale
,
).
view
(
decode_query
.
shape
)
# Reshape the output tensor.
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment