norm / vllm · Commits · 3e9f991d

Unverified commit 3e9f991d, authored Mar 01, 2023 by Woosuk Kwon, committed via GitHub on Mar 01, 2023.

Use FlashAttention for `multi_query_kv_attention` (#4)

Parent: 0deacbce
Showing 4 changed files with 108 additions and 34 deletions (+108, -34).
README.md                         +1   -0
cacheflow/models/attention.py     +37  -30
server.py                         +5   -2
tests/kernels/attention.py        +65  -2
README.md

````diff
@@ -4,6 +4,7 @@
 ```bash
 pip install cmake torch transformers
+pip install flash-attn  # This may take up to 10 mins.
 pip install -e .
 ```
````
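As a quick sanity check before installing CacheFlow itself (not part of this commit, just a hedged suggestion), one might confirm that the flash-attn build is importable and that a CUDA device is visible; the import path below is the one this commit uses:

```python
# Hypothetical sanity check, not part of this commit.
import torch
from flash_attn.flash_attention import FlashAttention  # import path used by this commit

# FlashAttention requires a CUDA GPU and half-precision inputs;
# the commit explicitly rejects torch.float in multi_query_kv_attention.
assert torch.cuda.is_available(), 'FlashAttention requires a CUDA-capable GPU.'
print(torch.cuda.get_device_name(0))
print(FlashAttention)
```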
cacheflow/models/attention.py

```diff
 from typing import List, Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn
@@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
 
-    def _masked_attention(
-        self,
-        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
-        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
-        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
-    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)
 
     def multi_query_kv_attention(
         self,
@@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
-        for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError(
+                'FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
+        for prompt_len in prompt_lens:
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
@@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)
 
         # Wait until the cache op is done.
         if cache_event is not None:
```
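The key change above packs all prompt tokens into a single varlen FlashAttention call, with `cu_seqlens` (the `prefix_sum` tensor) marking where each prompt begins and ends along the packed token dimension. A minimal sketch of that bookkeeping, using pure PyTorch and hypothetical prompt lengths, shows how the boundaries are derived and how each prompt's rows would be recovered from the packed output:

```python
import torch

# Hypothetical prompt lengths; any values work.
prompt_lens = [2, 3, 4]

# Cumulative boundaries, as in the commit's prefix_sum / cu_seqlens.
prefix_sum = [0]
for prompt_len in prompt_lens:
    prefix_sum.append(prefix_sum[-1] + prompt_len)
cu_seqlens = torch.tensor(prefix_sum, dtype=torch.int)  # tensor([0, 2, 5, 9])

num_tokens = prefix_sum[-1]              # 9 packed tokens in total
packed = torch.randn(num_tokens, 3, 8)   # stand-in for [num_tokens, num_heads, head_size]

# Rows belonging to prompt i live in packed[cu_seqlens[i]:cu_seqlens[i + 1]].
for i in range(len(prompt_lens)):
    rows = packed[cu_seqlens[i]:cu_seqlens[i + 1]]
    assert rows.shape[0] == prompt_lens[i]
```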
server.py

```diff
@@ -9,10 +9,12 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='token block size')
+parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
 # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
 parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
+# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 args = parser.parse_args()
@@ -27,6 +29,7 @@ def main():
         block_size=args.block_size,
         num_gpu_blocks=args.num_gpu_blocks,
         num_cpu_blocks=args.num_cpu_blocks,
+        dtype=args.dtype,
     )
     controllers.append(controller)
```
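For reference, the new `--dtype` flag is a plain argparse choice, so invalid values are rejected at parse time. The snippet below is a standalone sketch mirroring the flags added or changed in this commit, not the server's actual parser:

```python
import argparse

# Standalone sketch mirroring the flags added/changed in this commit.
parser = argparse.ArgumentParser(description='CacheFlow server')
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16],
                    help='token block size')
parser.add_argument('--dtype', type=str, default='half',
                    choices=['half', 'float'], help='data type')

args = parser.parse_args([])                   # defaults: block_size=8, dtype='half'
args = parser.parse_args(['--dtype', 'half'])  # OK; required when FlashAttention is used
# parser.parse_args(['--dtype', 'bf16'])       # would exit with an error: invalid choice
```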
tests/kernels/attention.py

```diff
 import random
 from typing import Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 
 from cacheflow import attention_ops
 
+MAX_SEQ_LEN = 4096
+
 
 def ref_masked_attention(
     query: torch.Tensor,
@@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-    context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def test_multi_query_kv_attention(
+    num_seqs: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    max_seq_len = max(seq_lens)
+    num_tokens = sum(seq_lens)
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
+
+    scale = float(1.0 / (head_size ** 0.5))
+    query = torch.randn(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    key = torch.rand_like(query)
+    value = torch.rand_like(query)
+
+    qkv = torch.stack([query, key, value], dim=1)
+    flash_attn = FlashAttention(softmax_scale=scale)
+    output = flash_attn(
+        qkv,
+        cu_seqlens=cu_seq_lens,
+        max_s=max_seq_len,
+        causal=True,
+    )[0]
+
+    ref_outputs = []
+    for i, seq_len in enumerate(seq_lens):
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+
+
 @torch.inference_mode()
 def test_attention() -> None:
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
-            for head_size in [64, 80, 96, 128, 256]:
+            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -137,6 +189,17 @@ def test_attention() -> None:
                     dtype=dtype,
                 )
 
+    # NOTE(woosuk): FlashAttention does not support FP32.
+    for dtype in [torch.half]:
+        # NOTE(woosuk): FlashAttention does not support head_size > 128.
+        for head_size in [64, 80, 96, 128]:
+            test_multi_query_kv_attention(
+                num_seqs=11,
+                num_heads=3,
+                head_size=head_size,
+                dtype=dtype,
+            )
+
 
 if __name__ == '__main__':
     test_attention()
```
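The new test compares the packed FlashAttention output against `ref_masked_attention`, which is not shown in this diff. Based on the call site in the test and on the `_masked_attention` helper removed from `cacheflow/models/attention.py` above, a plausible reference (an assumption, not the verbatim function) looks like this:

```python
from typing import Optional

import torch


def ref_masked_attention(
    query: torch.Tensor,    # [num_queries, num_heads, head_size]
    key: torch.Tensor,      # [num_keys, num_heads, head_size]
    value: torch.Tensor,    # [num_keys, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
) -> torch.Tensor:
    # Assumed reference; mirrors the removed _masked_attention helper,
    # with the softmax scale passed in explicitly as the test does.
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)
```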