Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7c1b7f3
Commit
e7c1b7f3
authored
Sep 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.4-dtk24.04.1'
parents
7462218e
04c62b93
Changes
442
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2077 additions
and
241 deletions
+2077
-241
vllm/assets/image.py
vllm/assets/image.py
+39
-0
vllm/attention/__init__.py
vllm/attention/__init__.py
+3
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+45
-4
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+30
-6
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+210
-9
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+332
-34
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+369
-0
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+101
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+31
-16
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+110
-29
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+24
-11
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+237
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+419
-86
vllm/attention/layer.py
vllm/attention/layer.py
+27
-12
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+3
-2
vllm/attention/ops/blocksparse_attention/utils.py
vllm/attention/ops/blocksparse_attention/utils.py
+33
-7
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+4
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+22
-12
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+12
-4
vllm/attention/selector.py
vllm/attention/selector.py
+26
-6
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
vllm/assets/image.py
0 → 100644
View file @
e7c1b7f3
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
Literal
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
.base
import
get_cache_dir
@
lru_cache
def
get_air_example_data_2_asset
(
filename
:
str
)
->
Image
.
Image
:
"""
Download and open an image from
``s3://air-example-data-2/vllm_opensource_llava/``.
"""
image_directory
=
get_cache_dir
()
/
"air-example-data-2"
image_directory
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
image_path
=
image_directory
/
filename
if
not
image_path
.
exists
():
base_url
=
"https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
global_http_connection
.
download_file
(
f
"
{
base_url
}
/
{
filename
}
"
,
image_path
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
return
Image
.
open
(
image_path
)
@
dataclass
(
frozen
=
True
)
class
ImageAsset
:
name
:
Literal
[
"stop_sign"
,
"cherry_blossom"
]
@
property
def
pil_image
(
self
)
->
Image
.
Image
:
return
get_air_example_data_2_asset
(
f
"
{
self
.
name
}
.jpg"
)
vllm/attention/__init__.py
View file @
e7c1b7f3
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
...
...
@@ -7,6 +8,7 @@ __all__ = [
"Attention"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"AttentionMetadataBuilder"
,
"Attention"
,
"get_attn_backend"
,
]
vllm/attention/backends/abstract.py
View file @
e7c1b7f3
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
from
enum
import
Enum
,
auto
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
class
AttentionType
(
Enum
):
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
ENCODER
=
auto
()
# Encoder attention between previous layer Q/K/V
ENCODER_DECODER
=
auto
()
# Attention between dec. Q and enc. K/V
class
AttentionBackend
(
ABC
):
"""Abstract class for attention backends."""
...
...
@@ -21,9 +31,23 @@ class AttentionBackend(ABC):
@
staticmethod
@
abstractmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
get
_metadata
_cls
()
->
Type
[
"AttentionMetadata"
]
:
raise
NotImplementedError
@
classmethod
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
abstractmethod
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
NotImplementedError
@
classmethod
def
make_metadata_builder
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadataBuilder"
:
return
cls
.
get_builder_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
abstractmethod
def
get_kv_cache_shape
(
...
...
@@ -99,6 +123,20 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
"""Abstract class for attention metadata builders."""
@
abstractmethod
def
__init__
(
self
,
input_builder
:
"ModelRunnerInputBuilderBase"
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
@
abstractmethod
...
...
@@ -112,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
raise
NotImplementedError
...
...
@@ -123,6 +162,8 @@ class AttentionImpl(ABC, Generic[T]):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
e7c1b7f3
...
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
...
@@ -90,8 +91,12 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
return
BlocksparseFlashAttentionImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"BlocksparseFlashAttentionMetadata"
:
return
BlocksparseFlashAttentionMetadata
(
*
args
,
**
kwargs
)
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
BlocksparseFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
return
BlocksparseFlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
class
BlocksparseFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
BlocksparseFlashAttentionMetadata
]):
_metadata_cls
=
BlocksparseFlashAttentionMetadata
class
BlocksparseFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
@@ -272,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
not
None
assert
alibi_slopes
is
None
,
ValueError
(
"Alibi not support for blocksparse flash attention."
)
assert
sliding_window
is
None
,
ValueError
(
"sliding_window is invalid for blocksparse attention."
)
assert
logits_soft_cap
is
None
,
ValueError
(
"logits_soft_cap is invalid for blocksparse attention."
)
if
"num_heads"
not
in
blocksparse_params
:
blocksparse_params
[
"num_heads"
]
=
num_heads
...
...
@@ -327,7 +341,9 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -340,6 +356,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -361,7 +383,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
...
@@ -398,7 +421,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
=
self
.
tp_rank
,
blocksparse_local_blocks
=
self
.
local_blocks
,
blocksparse_vert_stride
=
self
.
vert_stride
,
...
...
vllm/attention/backends/flash_attn.py
View file @
e7c1b7f3
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.utils
import
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -25,8 +34,12 @@ class FlashAttentionBackend(AttentionBackend):
return
FlashAttentionImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"FlashAttentionMetadata"
:
return
FlashAttentionMetadata
(
*
args
,
**
kwargs
)
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -83,7 +96,7 @@ class FlashAttentionMetadata(AttentionMetadata):
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------
-
|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
...
...
@@ -184,6 +197,175 @@ class FlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
class
FlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
@@ -220,9 +402,11 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"FlashAttention does not support block-sparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -233,6 +417,10 @@ class FlashAttentionImpl(AttentionImpl):
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -256,7 +444,9 @@ class FlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
...
...
@@ -269,8 +459,15 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
...
...
@@ -292,6 +489,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
...
@@ -329,6 +528,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
...
...
@@ -348,6 +548,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
vllm/attention/backends/flashinfer.py
View file @
e7c1b7f3
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
import
flashinfer
import
torch
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -22,8 +38,12 @@ class FlashInferBackend(AttentionBackend):
return
FlashInferImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"FlashInferMetadata"
:
return
FlashInferMetadata
(
*
args
,
**
kwargs
)
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashInferMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
return
FlashInferMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -40,14 +60,14 @@ class FlashInferBackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
raise
NotImplementedError
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
raise
NotImplementedError
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
...
...
@@ -60,19 +80,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len
:
int
use_cuda_graph
:
bool
=
Fals
e
use_cuda_graph
:
bool
=
Tru
e
prefill_wrapper
:
Optional
[
BatchPrefillWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer
:
Optional
[
torch
.
Tensor
]
=
None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
...
...
@@ -98,6 +115,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
def
__post_init__
(
self
):
# Refer to
...
...
@@ -109,13 +127,44 @@ class FlashInferMetadata(AttentionMetadata):
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"received
{
self
.
head_dim
}
."
)
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if
self
.
num_prefills
==
0
:
assert
self
.
num_decode_tokens
>
0
self
.
decode_wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
)
def
begin_forward
(
self
):
if
self
.
num_prefill_tokens
>
0
:
if
self
.
paged_kv_indices
is
None
:
return
assert
self
.
prefill_wrapper
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
# The prefill stage does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
self
.
paged_kv_indptr
=
torch
.
zeros
(
batch_size
+
1
,
device
=
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
)
else
:
if
not
self
.
use_cuda_graph
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
assert
self
.
decode_wrapper
is
not
None
self
.
decode_wrapper
.
end_forward
()
self
.
decode_wrapper
.
begin_forward
(
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
...
...
@@ -133,8 +182,9 @@ class FlashInferMetadata(AttentionMetadata):
)
->
Dict
[
str
,
Any
]:
if
skip_fields
is
None
:
skip_fields
=
set
()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the
prefill/
decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields
.
add
(
'prefill_wrapper'
)
skip_fields
.
add
(
'decode_wrapper'
)
return
super
().
asdict_zerocopy
(
skip_fields
)
...
...
@@ -157,6 +207,234 @@ class FlashInferMetadata(AttentionMetadata):
return
self
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
self
.
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
self
.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
computed_block_nums
=
inter_data
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
is_profile_run
=
is_block_tables_empty
(
block_tables
)
# Compute slot mapping.
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if
is_profile_run
:
return
block_table
=
block_tables
[
seq_id
]
self
.
_update_paged_kv_tensors
(
block_table
,
seq_len
)
def
_update_paged_kv_tensors
(
self
,
block_table
:
List
[
int
],
seq_len
:
int
):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_len
.
extend
([
0
]
*
cuda_graph_pad_size
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
self
.
paged_kv_indptr
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
self
.
paged_kv_last_page_len
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
),
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
self
.
runner
.
parallel_config
),
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
)
class
FlashInferImpl
(
AttentionImpl
):
def
__init__
(
...
...
@@ -168,6 +446,8 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -180,6 +460,7 @@ class FlashInferImpl(AttentionImpl):
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -191,9 +472,17 @@ class FlashInferImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
FlashInferMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
assert
kv_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashInfer."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -215,12 +504,18 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
block_tables
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
is
None
:
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
...
...
@@ -235,16 +530,19 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
raise
NotImplementedError
(
"Prefix caching is not supported with flashinfer yet."
)
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
kv_cache
,
logits_soft_cap
=
self
.
logits_soft_cap
,
causal
=
True
)
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
query
,
kv_cache
,
sm_scale
=
self
.
scale
,
)
logits_soft_cap
=
self
.
logits_soft_cap
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/ipex_attn.py
0 → 100644
View file @
e7c1b7f3
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
_PARTITION_SIZE
=
512
class
IpexAttnBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"ipex-attn"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"IpexAttnBackendImpl"
]:
return
IpexAttnBackendImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
return
IpexAttnMetadata
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for IpexAttnBackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
seq_lens
:
Optional
[
List
[
int
]]
seqlen_q
:
Optional
[
torch
.
Tensor
]
max_seqlen
:
Optional
[
int
]
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_decode_tokens
==
0
:
assert
self
.
num_prefills
>
0
return
self
return
None
@
property
def
decode_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
return
None
return
self
class
IpexAttnBackendImpl
(
AttentionImpl
[
IpexAttnMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"IPEX backend does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"IPEX backend does not support logits_soft_cap."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
def
split_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
1
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
key_cache
,
value_cache
=
self
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
ipex_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
None
,
dtype
=
query
.
dtype
)
attn_metadata
.
attn_bias
=
att_masks
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
,
device
=
query
.
device
)
ipex_ops
.
varlen_attention
(
query
,
key
,
value
,
output
,
attn_metadata
.
seqlen_q
,
attn_metadata
.
seqlen_q
,
attn_metadata
.
max_seqlen
,
attn_metadata
.
max_seqlen
,
pdropout
=
0.0
,
softmax_scale
=
self
.
scale
,
zero_tensors
=
False
,
is_causal
=
True
,
return_softmax
=
False
,
gen_
=
None
)
else
:
# prefix-enabled attention
raise
RuntimeError
(
"IPEX backend doesn't support prefix decoding."
)
else
:
# Decoding run.
max_seq_len
=
attn_metadata
.
max_decode_seq_len
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory
# shortage.
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
ipex_ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
self
.
num_kv_heads
,
self
.
scale
,
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ipex_ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
self
.
num_kv_heads
,
self
.
scale
,
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
alibi_slopes
.
device
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
,
device
=
alibi_slopes
.
device
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/attention/backends/openvino.py
0 → 100644
View file @
e7c1b7f3
from
dataclasses
import
dataclass
from
typing
import
List
,
Tuple
import
openvino
as
ov
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
class
OpenVINOAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"openvino"
@
staticmethod
def
get_impl_cls
():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise
NotImplementedError
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
raise
NotImplementedError
@
staticmethod
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
ov
.
Tensor
,
dst_kv_cache
:
ov
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise
NotImplementedError
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
Tuple
[
ov
.
Tensor
,
ov
.
Tensor
]],
src_to_dists
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
for
src
,
dst
in
src_to_dists
:
for
key_cache
,
value_cache
in
kv_caches
:
key_cache
.
data
[
dst
,
:]
=
key_cache
.
data
[
src
,
:]
value_cache
.
data
[
dst
,
:]
=
value_cache
.
data
[
src
,
:]
@
dataclass
class
OpenVINOAttentionMetadata
:
"""Metadata for OpenVINOAttentionBackend.
Basic terms used below:
- batch_size_in_sequences - total number of sequences to execute
- prompt_lens – per sequence size number of scheduled tokens
- batch_size_in_tokens = sum(prompt_lens)
- max_context_len = max(context_lens)
- max_num_blocks = div_up(max_context_len / BLOCK_SIZE)
- num_blocks – total number of blocks in block_indices
"""
# Describes past KV cache size for each sequence within a batch
# Shape: [batch_size_in_sequences]
# Type: i32
past_lens
:
torch
.
Tensor
# Describes start indices of input / speculative tokens from
# current sequences within a batch sequence
# Shape: [batch_size_in_sequences + 1]
# Type: i32
subsequence_begins
:
torch
.
Tensor
# Describes block tables for each sequence within a batch -
# indices along 0th dimension in key_cache and value_cache inputs
# Shape: [num_blocks]
# Type: i32
block_indices
:
torch
.
Tensor
# Describes block tables for each sequence within a batch -
# for i-th element, it is an index in block_indices with the
# first block belonging to i-th sequence
# Shape: [batch_size_in_sequences + 1]
# Type: i32
block_indices_begins
:
torch
.
Tensor
# Describes max context length
# Shape: scalar
# Type: i32
max_context_len
:
torch
.
Tensor
vllm/attention/backends/pallas.py
View file @
e7c1b7f3
...
...
@@ -3,10 +3,9 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.dynamo_set_buffer_donor
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
class
PallasAttentionBackend
(
AttentionBackend
):
...
...
@@ -16,8 +15,8 @@ class PallasAttentionBackend(AttentionBackend):
return
PallasAttentionBackendImpl
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"PallasMetadata"
:
return
PallasMetadata
(
*
args
,
**
kwargs
)
def
get
_metadata
_cls
()
->
Type
[
"PallasMetadata"
]
:
return
PallasMetadata
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -32,17 +31,22 @@ class PallasAttentionBackend(AttentionBackend):
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
raise
NotImplemented
Error
(
"swap_blocks is not
implemente
d."
)
raise
Runtime
Error
(
"swap_blocks is not
used for the TPU backen
d."
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
src_to_dists
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
None
:
# TODO(woosuk): Implement this.
raise
NotImplementedError
(
"copy_blocks is not implemented."
)
src_indices
,
dst_indices
=
src_to_dists
for
k_cache
,
v_cache
in
kv_caches
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
k_cache
,
True
)
k_cache
[:,
dst_indices
]
=
k_cache
[:,
src_indices
]
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
dataclass
...
...
@@ -50,8 +54,8 @@ class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
context_lens
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
...
...
@@ -87,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -105,13 +110,16 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
logits_soft_cap
is
not
None
:
raise
NotImplementedError
(
"Attention logits soft-capping is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
self
.
megacore_mode
=
None
tpu_type
=
torch_xla
.
tpu
.
get_tp
_group
u_env
()[
"TYPE"
].
lower
()
if
not
tpu_type
.
endswith
(
"lite"
)
:
tpu_type
=
torch_xla
.
tpu
.
get_tpu_env
()[
"TYPE"
].
lower
()
if
"lite"
not
in
tpu_type
:
if
self
.
num_kv_heads
%
2
==
0
:
self
.
megacore_mode
=
"kv_head"
else
:
...
...
@@ -126,7 +134,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
PallasMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
...
...
@@ -140,7 +150,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert
kv_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
e7c1b7f3
...
...
@@ -6,7 +6,8 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
...
...
@@ -25,8 +26,12 @@ class ROCmFlashAttentionBackend(AttentionBackend):
return
ROCmFlashAttentionImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"ROCmFlashAttentionMetadata"
:
return
ROCmFlashAttentionMetadata
(
*
args
,
**
kwargs
)
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
ROCmFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
return
ROCmFlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -166,6 +171,43 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
.
_cached_decode_metadata
class
ROCmFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
_metadata_cls
=
ROCmFlashAttentionMetadata
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
Optional
[
List
[
int
]],
make_attn_mask
:
bool
=
True
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
if
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
(
(
num_heads
,
1
,
1
)).
to
(
alibi_slopes
.
device
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
make_attn_mask
:
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
).
to
(
alibi_slopes
.
device
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
else
:
attn_biases
.
append
(
bias
.
to
(
dtype
))
return
attn_biases
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
@@ -202,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"ROCFlashAttention does not support blocksparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support attention logits soft "
"capping."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -228,7 +276,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >
=
8000
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
...
...
@@ -246,6 +294,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
...
...
@@ -279,7 +333,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -292,6 +348,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -312,7 +374,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
...
@@ -338,9 +401,16 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_masks
=
None
if
self
.
use_triton_flash_attn
:
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
False
)
# type: ignore
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>
=
8000
:
if
prefill_meta
.
max_prefill_seq_len
>
8000
:
out
=
self
.
attn_func_triton
(
q
=
query
,
k
=
key
,
...
...
@@ -365,35 +435,40 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal
=
True
,
)
else
:
#
out, _
= self.attn_func(
#
out
= self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# prefill_meta.seq_lens,
# num_tokens,
# self.num_heads,
# self.head_size,
# self.scale,
# attn_masks,
# )
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
True
)
# type: ignore
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
...
...
@@ -407,6 +482,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
num_heads
,
self
.
head_size
,
self
.
scale
,
attn_masks
,
)
else
:
out
=
self
.
attn_func
(
...
...
@@ -419,6 +495,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
# window_size=self.sliding_window,
# alibi_slopes=self.alibi_slopes,
)
# common code for prefill
...
...
@@ -454,7 +532,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
# Reshape the output tensor.
...
...
@@ -470,13 +549,14 @@ def _sdpa_attention(
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
start
=
0
output
=
torch
.
empty
((
num_tokens
,
num_heads
,
head_size
),
dtype
=
query
.
dtype
,
device
=
query
.
device
)
for
seq_len
in
seq_lens
:
for
i
,
seq_len
in
enumerate
(
seq_lens
)
:
end
=
start
+
seq_len
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
enable_flash
=
False
,
...
...
@@ -486,7 +566,8 @@ def _sdpa_attention(
key
[:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
dropout_p
=
0.0
,
is_causal
=
True
,
is_causal
=
attn_masks
is
None
,
attn_mask
=
attn_masks
[
i
]
if
attn_masks
else
None
,
scale
=
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
...
...
vllm/attention/backends/torch_sdpa.py
View file @
e7c1b7f3
...
...
@@ -7,7 +7,7 @@ import torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.utils
import
is_cpu
...
...
@@ -31,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return
TorchSDPABackendImpl
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"TorchSDPA
Metadata"
:
return
TorchSDPAMetadata
(
*
args
,
**
kwargs
)
def
get
_metadata
_cls
()
->
Type
[
"Attention
Metadata"
]
:
return
TorchSDPAMetadata
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support logits soft cap."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -144,7 +148,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -157,7 +163,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
kv_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -170,7 +181,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
)
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
...
...
@@ -233,7 +245,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
# Reshape the output tensor.
...
...
@@ -245,7 +258,7 @@ def _make_alibi_bias(
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
...
...
@@ -271,7 +284,7 @@ def _make_sliding_window_bias(
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
...
...
vllm/attention/backends/utils.py
0 → 100644
View file @
e7c1b7f3
"""Attention backend utils"""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Type
,
TypeVar
,
Union
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.utils
import
make_tensor_with_pad
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
"with encoder/decoder models."
)
PAD_SLOT_ID
=
-
1
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
def
is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if
block_tables
is
None
:
return
True
if
isinstance
(
block_tables
,
dict
)
and
all
(
value
is
None
for
value
in
block_tables
.
values
()):
return
True
return
False
def
compute_slot_mapping_start_idx
(
is_prompt
:
bool
,
query_len
:
int
,
context_len
:
int
,
sliding_window
:
int
,
use_v2_block_manager
:
bool
):
"""
Compute the start index of slot mapping.
"""
start_idx
=
0
if
is_prompt
and
sliding_window
is
not
None
:
assert
use_v2_block_manager
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager"
)
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx
=
max
(
0
,
query_len
-
sliding_window
)
return
start_idx
def
compute_slot_mapping
(
is_profile_run
:
bool
,
slot_mapping
:
List
[
int
],
seq_id
:
int
,
seq_len
:
int
,
context_len
:
int
,
start_idx
:
int
,
block_size
:
int
,
block_tables
:
Dict
[
int
,
List
[
int
]]):
"""
Compute slot mapping.
"""
if
is_profile_run
:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
seq_len
)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table
=
block_tables
[
seq_id
]
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
max
(
0
,
start_idx
-
context_len
))
for
i
in
range
(
max
(
start_idx
,
context_len
),
seq_len
):
block_number
=
block_table
[
i
//
block_size
]
block_offset
=
i
%
block_size
slot
=
block_number
*
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
TAttentionMetadata
=
TypeVar
(
"TAttentionMetadata"
,
bound
=
'AttentionMetadata'
)
class
CommonMetadataBuilder
(
AttentionMetadataBuilder
[
TAttentionMetadata
]):
_metadata_cls
:
Type
[
TAttentionMetadata
]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
computed_block_nums
=
inter_data
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
vllm/attention/backends/xformers.py
View file @
e7c1b7f3
...
...
@@ -6,10 +6,12 @@ import torch
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
BlockDiagonalCausalMask
,
BlockDiagonalMask
,
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
...
...
@@ -28,8 +30,12 @@ class XFormersBackend(AttentionBackend):
return
XFormersImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"XFormersMetadata"
:
return
XFormersMetadata
(
*
args
,
**
kwargs
)
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
XFormersMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
return
XFormersMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -66,11 +72,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
...
...
@@ -79,8 +80,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
...
...
@@ -88,26 +90,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
=
None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
...
...
@@ -115,6 +146,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
...
...
@@ -122,30 +175,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
# Recover cached prefill-phase attention
# metadata structure
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
block_tables
is
not
None
assert
((
self
.
seq_lens
is
not
None
)
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
# Construct & cache prefill-phase attention metadata structure
self
.
_cached_prefill_metadata
=
XFormersMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
]
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
]
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
]
,
slot_mapping
=
slot_mapping
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
None
,
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
query_start_loc
=
query_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
@
property
...
...
@@ -154,29 +227,151 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
# Recover cached decode-phase attention
# metadata structure
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
# Construct & cache decode-phase attention metadata structure
self
.
_cached_decode_metadata
=
XFormersMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
slot_mapping
=
slot_mapping
,
seq_lens_tensor
=
seq_lens_tensor
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
def
_get_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_type
:
AttentionType
,
)
->
Optional
[
AttentionBias
]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if
attn_type
==
AttentionType
.
DECODER
:
return
attn_metadata
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
attn_metadata
.
encoder_attn_bias
else
:
# attn_type == AttentionType.ENCODER_DECODER
return
attn_metadata
.
cross_attn_bias
def
_set_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_bias
:
List
[
Optional
[
AttentionBias
]],
attn_type
:
AttentionType
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if
attn_type
==
AttentionType
.
DECODER
:
attn_metadata
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
attn_metadata
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
attn_metadata
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
_get_seq_len_block_table_args
(
attn_metadata
:
XFormersMetadata
,
is_prompt
:
bool
,
attn_type
:
AttentionType
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if
is_prompt
:
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
else
:
max_seq_len
=
attn_metadata
.
max_decode_seq_len
return
(
attn_metadata
.
seq_lens_tensor
,
max_seq_len
,
attn_metadata
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
_metadata_cls
=
XFormersMetadata
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
@@ -213,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"XFormer does not support block-sparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"XFormers does not support attention logits soft capping."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -238,51 +438,145 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
,
value
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
"XFormersMetadata"
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns:
shape = [num_tokens, num_heads * head_size]
"""
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
is
not
None
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
...
...
@@ -294,10 +588,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# block tables are empty if the prompt does not have a cached
# prefix.
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
)
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
else
:
assert
prefill_meta
.
query_start_loc
is
not
None
assert
prefill_meta
.
max_query_len
is
not
None
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
...
...
@@ -320,18 +618,26 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output
[:
num_prefill_tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
_get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_
tensor
,
decode_meta
.
max_decode
_seq_len
,
block_tables
_arg
,
seq_lens_
arg
,
max
_seq_len
_arg
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
# Reshape the output tensor.
...
...
@@ -343,6 +649,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
XFormersMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
...
...
@@ -356,8 +663,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
"""
assert
attn_metadata
.
seq_lens
is
not
None
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
...
...
@@ -375,18 +686,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if
attn_metadata
.
attn_bias
is
None
:
attn_bias
=
_get_attn_bias
(
attn_metadata
,
attn_type
)
if
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
if
(
attn_type
==
AttentionType
.
ENCODER_DECODER
):
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Default enc/dec cross-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
,
attn_metadata
.
encoder_seq_lens
)
elif
attn_type
==
AttentionType
.
ENCODER
:
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Default encoder self-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
encoder_seq_lens
)
else
:
assert
attn_metadata
.
seq_lens
is
not
None
# Default decoder self-attention mask is causal
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
attn_metadata
.
attn_bias
=
[
attn_bias
]
attn_bias
=
[
attn_bias
]
else
:
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
assert
attn_metadata
.
seq_lens
is
not
None
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
_set_attn_bias
(
attn_metadata
,
attn_bias
,
attn_type
)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
...
...
@@ -400,7 +732,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
,
key
,
value
,
attn_bias
=
attn_
metadata
.
attn_
bias
[
0
],
attn_bias
=
attn_bias
[
0
],
p
=
0.0
,
scale
=
self
.
scale
)
return
out
.
view_as
(
original_query
)
...
...
@@ -409,6 +741,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
assert
attn_metadata
.
seq_lens
is
not
None
output
=
torch
.
empty_like
(
original_query
)
start
=
0
for
i
,
seq_len
in
enumerate
(
attn_metadata
.
seq_lens
):
...
...
@@ -417,7 +750,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_
metadata
.
attn_
bias
[
i
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
...
...
@@ -431,8 +764,8 @@ def _make_alibi_bias(
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
L
owerTriangularMaskWithTensor
Bias
:
attn_biases
=
[]
)
->
L
ist
[
Attention
Bias
]
:
attn_biases
:
List
[
AttentionBias
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
...
...
vllm/attention/layer.py
View file @
e7c1b7f3
...
...
@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
class
Attention
(
nn
.
Module
):
...
...
@@ -33,6 +34,8 @@ class Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
if
cache_config
is
not
None
:
...
...
@@ -46,23 +49,27 @@ class Attention(nn.Module):
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
# The default kv_scale is set to 1.0. This is ignored
# The default k
/
v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# expect the pre-quantized k
/
v_scale to be loaded along
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
_kv_scale
=
1.0
self
.
_k_scale
=
1.0
self
.
_v_scale
=
1.0
quant_method
=
quant_config
.
get_quant_method
(
self
)
if
quant_config
else
None
self
,
prefix
=
prefix
)
if
quant_config
else
None
if
quant_method
is
not
None
:
assert
isinstance
(
quant_method
,
BaseKVCacheMethod
)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if
self
.
kv_cache_dtype
==
"fp8_e5m2"
:
raise
ValueError
(
"fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints."
)
#
When FP8
quantization is enabled, we make
a parameter
#
"kv_scale"
so that it can be loaded from
FP8
checkpoint.
# The kv_scale will then be converted back
#
to self._kv_scale in a native float32
value after weight loading.
#
If
quantization is enabled, we make
"k_scale" and "v_scale"
#
parameters
so that it can be loaded from
the model
checkpoint.
# The k
/
v_scale will then be converted back
to native float32
# value
s
after weight loading.
self
.
quant_method
=
quant_method
self
.
quant_method
.
create_weights
(
self
)
...
...
@@ -76,7 +83,7 @@ class Attention(nn.Module):
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
)
blocksparse_params
,
logits_soft_cap
)
def
forward
(
self
,
...
...
@@ -85,9 +92,17 @@ class Attention(nn.Module):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_kv_scale
)
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
...
vllm/attention/ops/blocksparse_attention/interface.py
View file @
e7c1b7f3
...
...
@@ -2,13 +2,14 @@ import math
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
IS_COMPUTE_8_OR_ABOVE
=
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
)
and
current_platform
.
get_device_capability
()[
0
]
>=
8
)
if
IS_COMPUTE_8_OR_ABOVE
:
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
...
...
@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
v
,
cu_seqlens_k
,
cu_seqlens_q
=
cu_seqlens_q
,
sm_scale
=
sm_scale
)
\ No newline at end of file
sm_scale
=
sm_scale
)
vllm/attention/ops/blocksparse_attention/utils.py
View file @
e7c1b7f3
...
...
@@ -4,9 +4,35 @@
from
functools
import
lru_cache
import
numpy
as
np
import
torch
import
triton
from
scipy
import
sparse
class
csr_matrix
:
"""Simple implementation of CSR matrix conversion without scipy.
This replaced scipy.sparse.csr_matrix() previously used."""
def
__init__
(
self
,
input_array
):
if
not
isinstance
(
input_array
,
np
.
ndarray
):
raise
ValueError
(
"Input must be a NumPy array"
)
self
.
shape
=
input_array
.
shape
rows
,
cols
=
self
.
shape
data
=
[]
indices
=
[]
indptr
=
[
0
]
for
i
in
range
(
rows
):
for
j
in
range
(
cols
):
if
input_array
[
i
,
j
]:
data
.
append
(
input_array
[
i
,
j
])
indices
.
append
(
j
)
indptr
.
append
(
len
(
indices
))
self
.
data
=
np
.
array
(
data
)
self
.
indices
=
np
.
array
(
indices
)
self
.
indptr
=
np
.
array
(
indptr
)
def
dense_to_crow_col
(
x
:
torch
.
Tensor
):
...
...
@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert
x
.
dim
()
in
(
2
,
3
)
if
x
.
dim
()
==
2
:
x
=
x
[
None
]
x
=
[
sparse
.
csr_matrix
(
xi
.
bool
().
cpu
().
numpy
())
for
xi
in
x
]
x
=
[
csr_matrix
(
xi
.
bool
().
cpu
().
numpy
())
for
xi
in
x
]
crows
=
torch
.
vstack
([
torch
.
from_numpy
(
xi
.
indptr
)
for
xi
in
x
])
cols
=
[
torch
.
from_numpy
(
xi
.
indices
)
for
xi
in
x
]
max_cols
=
max
(
len
(
xi
)
for
xi
in
cols
)
...
...
@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head(
):
"""
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
otherwise, None
"""
with
torch
.
no_grad
():
...
...
@@ -148,10 +174,10 @@ def get_sparse_attn_mask(
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be OOM if it
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
"""
assert
dense_mask_type
in
(
"binary"
,
"bias"
)
...
...
vllm/attention/ops/ipex_attn.py
View file @
e7c1b7f3
...
...
@@ -45,7 +45,8 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
...
...
@@ -64,7 +65,8 @@ class PagedAttention:
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
*
args
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
...
...
vllm/attention/ops/paged_attn.py
View file @
e7c1b7f3
...
...
@@ -4,9 +4,12 @@ from typing import List, Optional, Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.
attention.ops.prefix_pref
il
l
import
context_attention_fwd
from
vllm.
triton_ut
il
s
import
HAS_TRITON
import
vllm.envs
as
envs
if
HAS_TRITON
:
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
...
...
@@ -32,7 +35,7 @@ class PagedAttention:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
return
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -67,7 +70,8 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
...
...
@@ -76,7 +80,8 @@ class PagedAttention:
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
@
staticmethod
...
...
@@ -91,7 +96,8 @@ class PagedAttention:
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
...
@@ -127,7 +133,7 @@ class PagedAttention:
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v1_opt
(
output
,
...
...
@@ -142,7 +148,8 @@ class PagedAttention:
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -163,7 +170,8 @@ class PagedAttention:
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -190,7 +198,7 @@ class PagedAttention:
print
(
f
"exp_sums.shape =
{
exp_sums
.
shape
}
, max_logits.shape =
{
max_logits
.
shape
}
, tmp_output.shape =
{
tmp_output
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v2_opt
(
output
,
...
...
@@ -208,7 +216,8 @@ class PagedAttention:
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -232,7 +241,8 @@ class PagedAttention:
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
@@ -296,4 +306,4 @@ class PagedAttention:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
\ No newline at end of file
vllm/attention/ops/prefix_prefill.py
View file @
e7c1b7f3
...
...
@@ -5,6 +5,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.platforms
import
current_platform
if
triton
.
__version__
>=
"2.1.0"
:
@
triton
.
jit
...
...
@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0":
alibi_slopes
=
None
,
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
current_platform
.
get_device_capability
()
BLOCK
=
32
if
cap
[
0
]
>=
8
else
32
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if
q
.
dtype
is
torch
.
float32
:
BLOCK
=
BLOCK
//
2
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
...
...
@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0":
b_ctx_len
,
alibi_slopes
,
v_cache
.
shape
[
3
],
8
,
k_cache
.
shape
[
4
]
,
o
,
b_loc
.
stride
(
0
),
b_loc
.
stride
(
1
),
...
...
@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_ctx_len
,
v_cache
.
shape
[
3
],
8
,
k_cache
.
shape
[
4
]
,
o
,
b_loc
.
stride
(
0
),
b_loc
.
stride
(
1
),
...
...
@@ -802,4 +810,4 @@ if triton.__version__ >= "2.1.0":
num_warps
=
num_warps
,
num_stages
=
1
,
)
return
return
\ No newline at end of file
vllm/attention/selector.py
View file @
e7c1b7f3
...
...
@@ -7,7 +7,8 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
logger
=
init_logger
(
__name__
)
...
...
@@ -17,8 +18,10 @@ class _Backend(enum.Enum):
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
...
...
@@ -58,16 +61,23 @@ def get_attn_backend(
ROCmFlashAttentionBackend
)
return
ROCmFlashAttentionBackend
elif
backend
==
_Backend
.
TORCH_SDPA
:
# TODO: make XPU backend available here.
assert
is_cpu
(),
RuntimeError
(
"Torch SDPA backend is only used for the CPU device."
)
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
OPENVINO
:
logger
.
info
(
"Using OpenVINO Attention backend."
)
from
vllm.attention.backends.openvino
import
OpenVINOAttentionBackend
return
OpenVINOAttentionBackend
elif
backend
==
_Backend
.
IPEX
:
assert
is_xpu
(),
RuntimeError
(
"IPEX attention backend is only used for the XPU device."
)
logger
.
info
(
"Using IPEX attention backend."
)
from
vllm.attention.backends.ipex_attn
import
IpexAttnBackend
return
IpexAttnBackend
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
elif
backend
==
_Backend
.
PALLAS
:
...
...
@@ -107,6 +117,16 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
if
is_openvino
():
if
selected_backend
!=
_Backend
.
OPENVINO
:
logger
.
info
(
"Cannot use %s backend on OpenVINO."
,
selected_backend
)
return
_Backend
.
OPENVINO
if
is_xpu
():
if
selected_backend
!=
_Backend
.
IPEX
:
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
return
_Backend
.
IPEX
if
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
...
...
@@ -117,7 +137,7 @@ def which_attn_to_use(
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
...
...
@@ -126,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
current_platform
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
...
...
Prev
1
…
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment