Commit 928de468 (unverified)
Authored Oct 16, 2023 by Woosuk Kwon; committed by GitHub on Oct 16, 2023.
Implement PagedAttention V2 (#1348)
Parent: 29678cd2
Showing 6 changed files with 764 additions and 139 deletions (+764, -139):
benchmarks/kernels/benchmark_paged_attention.py   +197  -0
csrc/attention.cpp                                 +24  -4
csrc/attention/attention_kernels.cu               +413  -71
csrc/attention/dtype_bfloat16.cuh                   +5  -0
tests/kernels/test_attention.py                    +54  -17
vllm/model_executor/layers/attention.py            +71  -47
benchmarks/kernels/benchmark_paged_attention.py (new file, 0 → 100644)
import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
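Usage note (not part of the commit): with the flags defined by the argparse section above, the benchmark can be invoked along the lines of

    python benchmarks/kernels/benchmark_paged_attention.py --version v2 --batch-size 8 --context-len 4096 --dtype half

With these defaults, x = 16 // element_size(half) = 8, so the key cache created above has shape [NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x] = [1024, 8, 16, 16, 8], and the value cache has shape [1024, 8, 128, 16].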
csrc/attention.cpp
 #include <torch/extension.h>
 #include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
   ...
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes);

+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
 }
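The V2 entry point adds three per-partition buffers (tmp_out, exp_sums, max_logits) that the V1 signature does not have. The CUDA kernel that fills and reduces them (csrc/attention/attention_kernels.cu, below) is collapsed in this view, so as an illustration only — not code from this commit — here is a minimal PyTorch sketch of how partial results of these shapes can be combined, assuming tmp_out holds each partition's locally normalized attention output and exp_sums / max_logits hold each partition's softmax denominator and maximum logit:

    import torch

    def combine_partitions(
        tmp_out: torch.Tensor,     # [num_seqs, num_heads, num_partitions, head_size]
        exp_sums: torch.Tensor,    # [num_seqs, num_heads, num_partitions]
        max_logits: torch.Tensor,  # [num_seqs, num_heads, num_partitions]
    ) -> torch.Tensor:
        # Rescale each partition's softmax denominator to the global max logit
        # so the per-partition softmaxes become comparable.
        global_max = max_logits.max(dim=-1, keepdim=True).values
        rescaled = exp_sums * torch.exp(max_logits - global_max)
        # Weight each partition's (locally normalized) output by its share of
        # the global denominator and sum over the partition dimension.
        weights = rescaled / rescaled.sum(dim=-1, keepdim=True)
        return (tmp_out * weights.unsqueeze(-1)).sum(dim=2)  # [num_seqs, num_heads, head_size]

This is the standard split-softmax reduction; the actual kernel performs it on the GPU without materializing the weights tensor.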
csrc/attention/attention_kernels.cu
(This diff is collapsed in the commit view and is not reproduced here.)
csrc/attention/dtype_bfloat16.cuh
@@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }

+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
 // Zero-out a variable.
 inline __device__ void zero(__nv_bfloat16& dst) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 ...
tests/kernels/test_attention.py
@@ -14,13 +14,14 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
+PARTITION_SIZE = 512

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
-NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-BLOCK_SIZES = [8, 16, 32]
+BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
@@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


+@pytest.mark.parametrize("version", ["v1", "v2"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_single_query_cached_kv_attention(
+def test_paged_attention(
     kv_cache_factory,
+    version: str,
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
@@ -162,8 +164,41 @@ def test_single_query_cached_kv_attention(
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    attention_ops.single_query_cached_kv_attention(
+    if version == "v1":
+        attention_ops.paged_attention_v1(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    elif version == "v2":
+        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+                          PARTITION_SIZE)
+        assert PARTITION_SIZE % block_size == 0
+        num_seqs, num_heads, head_size = output.shape
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
 ...
@@ -175,6 +210,8 @@ def test_single_query_cached_kv_attention(
             max_context_len,
             alibi_slopes,
         )
+    else:
+        assert False, f"Unknown version: {version}"

     # Run the reference implementation.
     ref_output = torch.empty_like(query)
 ...
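Usage note (not part of the diff): since the test is now parametrized over version, a single kernel variant can be selected with pytest's keyword filter, e.g. `pytest tests/kernels/test_attention.py -k v2` should run only the paged_attention_v2 cases, assuming the generated test ids contain the version string.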
vllm/model_executor/layers/attention.py
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.rotary_embedding import (
     RotaryEmbedding)

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512


 class PagedAttention(nn.Module):
@@ -130,6 +132,14 @@ class PagedAttention(nn.Module):
             output.copy_(out.squeeze(0))
         return output

+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
@@ -137,6 +147,7 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.
@@ -148,10 +159,54 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
 ...
@@ -161,7 +216,7 @@ class PagedAttention(nn.Module):
                 input_metadata.context_lens,
                 block_size,
                 input_metadata.max_context_len,
-                None,  # alibi_slopes
+                alibi_slopes,
             )

     def forward(
 ...
@@ -253,7 +308,7 @@ class PagedAttention(nn.Module):
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
                 query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata)
+                value_cache, input_metadata, self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
 ...
@@ -431,36 +486,5 @@ class PagedAttentionWithALiBi(PagedAttention):
             start += prompt_len
         return output

-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
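To make the V1/V2 dispatch heuristic in single_query_cached_kv_attention above concrete, here is a small standalone restatement with a few worked cases (illustration only, not part of the commit):

    # Illustration: the dispatch condition used above, with _PARTITION_SIZE = 512.
    _PARTITION_SIZE = 512

    def use_v1(num_seqs: int, num_heads: int, max_context_len: int) -> bool:
        max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return max_num_partitions == 1 or num_seqs * num_heads > 512

    assert use_v1(8, 64, 4096) is False  # 8 partitions, 8 * 64 = 512 (not > 512) -> V2
    assert use_v1(16, 64, 4096) is True  # 16 * 64 = 1024 > 512: enough parallel work -> V1
    assert use_v1(8, 64, 512) is True    # single partition, no reduction needed -> V1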