Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b57e6c59
Unverified
Commit
b57e6c59
authored
May 19, 2024
by
Woosuk Kwon
Committed by
GitHub
May 19, 2024
Browse files
[Kernel] Add flash-attn back (#4907)
parent
27ce8547
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
304 additions
and
61 deletions
+304
-61
requirements-cuda.txt
requirements-cuda.txt
+1
-1
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+208
-0
tests/models/test_big_models.py
tests/models/test_big_models.py
+1
-1
tests/models/test_fp8.py
tests/models/test_fp8.py
+5
-5
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+75
-54
vllm/attention/selector.py
vllm/attention/selector.py
+14
-0
No files found.
requirements-cuda.txt
View file @
b57e6c59
...
...
@@ -7,4 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post
1
# Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post
2
# Requires PyTorch 2.3.0
tests/kernels/test_flash_attn.py
0 → 100644
View file @
b57e6c59
from
typing
import
List
,
Optional
,
Tuple
import
pytest
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
NUM_HEADS
=
[(
16
,
16
),
(
32
,
8
),
(
64
,
8
)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
def
ref_paged_attn
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
kv_len
=
kv_lens
[
i
]
q
=
query
[
start_idx
:
start_idx
+
query_len
]
q
*=
scale
num_kv_blocks
=
(
kv_len
+
block_size
-
1
)
//
block_size
block_indices
=
block_tables
[
i
,
:
num_kv_blocks
]
k
=
key_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
k
=
k
[:
kv_len
]
v
=
value_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
v
=
v
[:
kv_len
]
if
q
.
shape
[
1
]
!=
k
.
shape
[
1
]:
k
=
torch
.
repeat_interleave
(
k
,
q
.
shape
[
1
]
//
k
.
shape
[
1
],
dim
=
1
)
v
=
torch
.
repeat_interleave
(
v
,
q
.
shape
[
1
]
//
v
.
shape
[
1
],
dim
=
1
)
attn
=
torch
.
einsum
(
"qhd,khd->hqk"
,
q
,
k
).
float
()
empty_mask
=
torch
.
ones
(
query_len
,
kv_len
)
mask
=
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
query_len
+
1
).
bool
()
if
sliding_window
is
not
None
:
sliding_window_mask
=
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
(
query_len
+
sliding_window
)
+
1
).
bool
().
logical_not
()
mask
|=
sliding_window_mask
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
outputs
.
append
(
out
)
start_idx
+=
query_len
return
torch
.
cat
(
outputs
,
dim
=
0
)
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
flash_attn_with_kvcache
(
q
=
query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
block_size
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_query_len
=
max
(
query_lens
)
max_kv_len
=
max
(
kv_lens
)
window_size
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
sum
(
query_lens
),
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
query_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
sliding_window
=
sliding_window
,
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/models/test_big_models.py
View file @
b57e6c59
...
...
@@ -12,7 +12,7 @@ MODELS = [
# "Deci/DeciLM-7b", # Broken
# "tiiuae/falcon-7b", # Broken
"EleutherAI/gpt-j-6b"
,
"mosaicml/mpt-7b"
,
#
"mosaicml/mpt-7b",
# Broken
# "Qwen/Qwen1.5-0.5B" # Broken,
]
...
...
tests/models/test_fp8.py
View file @
b57e6c59
...
...
@@ -25,18 +25,18 @@ EXPECTED_STRS_MAP = {
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models ('
,
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to '
,
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.'
,
'A neural network is a complex system modeled after the human brain, co
mposed
of interconnected nodes or "ne'
,
'Zeta-5, a highly advanced robot designed for menial labor, whirred
and beep
'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models.
Her
e'
,
'A neural network is a complex system modeled after the human brain, co
nsisting
of interconnected nodes or "ne'
,
'Zeta-5, a highly advanced robot designed for menial labor, whirred
to a
'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models.
Th
e'
,
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of'
,
'Here are the translations:
\n\n
**Japanese:** (Haya tori,
nem
uri
nemuri)
\n\n
**'
'Here are the translations:
\n\n
**Japanese:** (Haya
aki no
tori,
g
uri
o'
,
],
"meta-llama/Meta-Llama-3-8B-Instruct"
:
[
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained'
,
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to '
,
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.'
,
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne'
,
'In the
year 2154, the robotics lab at NeuroSpark Industries was on the cusp of
'
,
'In the
vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short
'
,
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The'
,
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of'
,
'Here are the translations:
\n\n
**Japanese:** (Haya aki wa mushi o tsukamu'
...
...
vllm/attention/backends/flash_attn.py
View file @
b57e6c59
"""Attention layer with Flash and PagedAttention.
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm._C
import
cache_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
_SUPPORTED_HEAD_SIZES
=
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -37,8 +33,9 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -46,18 +43,26 @@ class FlashAttentionBackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
FlashAttentionMetadata
(
AttentionMetadata
):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
...
...
@@ -99,6 +104,14 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
...
...
@@ -219,11 +232,15 @@ class FlashAttentionImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
if
sliding_window
is
not
None
:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise
ValueError
(
"Sliding window is not supported in FlashAttention."
)
if
head_size
not
in
_SUPPORTED_HEAD_SIZES
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
Paged
Attention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Head size
{
head_size
}
is not supported by
Flash
Attention. "
f
"Supported head sizes are:
{
_SUPPORTED_HEAD_SIZES
}
."
)
def
forward
(
self
,
...
...
@@ -234,17 +251,20 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
:
FlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention
and PagedAttention
.
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size
*
num_kv_heads
*
head_size]
kv_cache = [2, num_blocks, block_size
,
num_kv_heads
,
head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -252,16 +272,20 @@ class FlashAttentionImpl(AttentionImpl):
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
key_cache
,
value_cache
=
PagedAttention
.
split_
kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
)
cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
@@ -281,7 +305,8 @@ class FlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
...
...
@@ -302,38 +327,34 @@ class FlashAttentionImpl(AttentionImpl):
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
key
,
value
,
key_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decod
e
(
decode_query
,
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcach
e
(
decode_query
.
unsqueeze
(
1
)
,
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_decode_seq_len
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
)
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
).
squeeze
(
1
)
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/selector.py
View file @
b57e6c59
...
...
@@ -93,6 +93,20 @@ def _which_attn_to_use(
"torch.float16 or torch.bfloat16."
)
return
_Backend
.
XFORMERS
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
return
_Backend
.
XFORMERS
if
block_size
%
16
!=
0
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16."
)
return
_Backend
.
XFORMERS
if
sliding_window
is
not
None
:
logger
.
info
(
"Cannot use FlashAttention-2 backend due to sliding window."
)
return
_Backend
.
XFORMERS
try
:
import
vllm_flash_attn
# noqa: F401
except
ImportError
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment