norm / vllm · Commits · 3e9f991d

Unverified commit 3e9f991d, authored Mar 01, 2023 by Woosuk Kwon, committed by GitHub Mar 01, 2023
Use FlashAttention for `multi_query_kv_attention` (#4)
parent 0deacbce

Showing 4 changed files with 108 additions and 34 deletions (+108 −34)
README.md (+1 −0)
cacheflow/models/attention.py (+37 −30)
server.py (+5 −2)
tests/kernels/attention.py (+65 −2)
README.md

```diff
@@ -4,6 +4,7 @@
 pip install cmake torch transformers
+pip install flash-attn  # This may take up to 10 mins.
 pip install -e .
```
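After installation, the new dependency can be sanity-checked from Python. A minimal sketch, assuming only the import path this commit itself uses (`flash_attn.flash_attention.FlashAttention`):

```python
# Minimal post-install check; same import as cacheflow/models/attention.py below.
from flash_attn.flash_attention import FlashAttention

# Constructing the module does not require a GPU; running it does.
flash_attn = FlashAttention(softmax_scale=1.0)
print(type(flash_attn).__name__)  # FlashAttention
```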
cacheflow/models/attention.py

```diff
 from typing import List, Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn
```

```diff
@@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
 
-    def _masked_attention(
-        self,
-        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
-        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
-        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
-    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)
 
     def multi_query_kv_attention(
         self,
```
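The removed `_masked_attention` computed attention explicitly with einsums and an additive causal mask; FlashAttention now performs the same math in a single fused kernel. A standalone sketch of that reference computation, with illustrative (made-up) sizes, for comparison with the FlashAttention call in the next hunk:

```python
import torch

# Illustrative sizes only; the real tensors come from the model.
num_queries, num_keys, num_heads, head_size = 4, 4, 2, 8
scale = head_size ** -0.5
query = torch.randn(num_queries, num_heads, head_size)
key = torch.randn(num_keys, num_heads, head_size)
value = torch.randn(num_keys, num_heads, head_size)

# Same math as the removed _masked_attention: scaled QK^T, additive causal mask,
# softmax over keys, then a weighted sum of V.
attn_mask = torch.triu(torch.ones(num_queries, num_keys), diagonal=1) * -1e5
attn = torch.einsum('qhd,khd->hqk', query * scale, key) + attn_mask
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, value)  # [num_queries, num_heads, head_size]
```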
```diff
@@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
-        for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-            start_idx += prompt_len
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
+        for prompt_len in prompt_lens:
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
```
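A small worked example of the `prefix_sum` tensor that the hunk above passes to FlashAttention as `cu_seqlens` (prompt lengths are made up):

```python
import torch

prompt_lens = [2, 3, 1]  # illustrative prompt lengths
prefix_sum = [0]
for prompt_len in prompt_lens:
    prefix_sum.append(prefix_sum[-1] + prompt_len)
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int)
print(prefix_sum)  # tensor([0, 2, 5, 6], dtype=torch.int32)

# FlashAttention treats rows [0:2], [2:5], and [5:6] of the packed
# [num_prompt_tokens, 3, num_heads, head_size] qkv tensor as independent
# causal sequences, so no explicit attention mask is needed.
```

The two `FIXME` copies in the hunk (the `torch.stack` and the `output[:num_tokens].copy_`) are the cost of repacking the separate query/key/value tensors into this layout and copying the result back into `output`.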
```diff
@@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
```
```diff
@@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)
 
         # Wait until the cache op is done.
         if cache_event is not None:
```
server.py

```diff
@@ -9,10 +9,12 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='token block size')
+parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
 # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
 parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
+# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 args = parser.parse_args()
```

```diff
@@ -27,6 +29,7 @@ def main():
             block_size=args.block_size,
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
+            dtype=args.dtype,
         )
         controllers.append(controller)
```
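A short sketch of how the new `choices` constraints behave at parse time (the flag values below are arbitrary examples, not part of the commit):

```python
import argparse

parser = argparse.ArgumentParser(description='CacheFlow server')
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')

args = parser.parse_args(['--block-size', '16'])
print(args.block_size, args.dtype)  # 16 half

# An unsupported value is rejected by argparse itself, e.g.:
# parser.parse_args(['--block-size', '32'])  ->  error: argument --block-size: invalid choice: 32
```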
tests/kernels/attention.py

```diff
 import random
 from typing import Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 
 from cacheflow import attention_ops
 
+MAX_SEQ_LEN = 4096
+
 
 def ref_masked_attention(
     query: torch.Tensor,
```
```diff
@@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-    context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
```
```diff
@@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def test_multi_query_kv_attention(
+    num_seqs: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    max_seq_len = max(seq_lens)
+    num_tokens = sum(seq_lens)
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
+
+    scale = float(1.0 / (head_size ** 0.5))
+    query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    key = torch.rand_like(query)
+    value = torch.rand_like(query)
+
+    qkv = torch.stack([query, key, value], dim=1)
+    flash_attn = FlashAttention(softmax_scale=scale)
+    output = flash_attn(
+        qkv,
+        cu_seqlens=cu_seq_lens,
+        max_s=max_seq_len,
+        causal=True,
+    )[0]
+
+    ref_outputs = []
+    for i, seq_len in enumerate(seq_lens):
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
```
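The new test needs a CUDA device and, per the notes in the next hunk, half precision with head_size ≤ 128. A minimal direct invocation, with arbitrary sizes, might look like:

```python
import torch

# Hypothetical direct call to the test defined above; requires a CUDA device
# and flash-attn installed.
if torch.cuda.is_available():
    test_multi_query_kv_attention(num_seqs=4, num_heads=3, head_size=64, dtype=torch.half)
```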
```diff
 @torch.inference_mode()
 def test_attention() -> None:
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
-            for head_size in [64, 80, 96, 128, 256]:
+            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
```
```diff
@@ -137,6 +189,17 @@ def test_attention() -> None:
                     dtype=dtype,
                 )
 
+    # NOTE(woosuk): FlashAttention does not support FP32.
+    for dtype in [torch.half]:
+        # NOTE(woosuk): FlashAttention does not support head_size > 128.
+        for head_size in [64, 80, 96, 128]:
+            test_multi_query_kv_attention(
+                num_seqs=11,
+                num_heads=3,
+                head_size=head_size,
+                dtype=dtype,
+            )
+
 
 if __name__ == '__main__':
     test_attention()
```