vllm · Commit 928de468 (unverified)

Implement PagedAttention V2 (#1348)

Authored Oct 16, 2023 by Woosuk Kwon; committed via GitHub on Oct 16, 2023. Parent: 29678cd2.

Showing 6 changed files with 764 additions and 139 deletions (+764, −139):
benchmarks/kernels/benchmark_paged_attention.py   +197   −0
csrc/attention.cpp                                 +24   −4
csrc/attention/attention_kernels.cu               +413  −71
csrc/attention/dtype_bfloat16.cuh                   +5   −0
tests/kernels/test_attention.py                    +54  −17
vllm/model_executor/layers/attention.py            +71  −47
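The commit splits each sequence's context into fixed-size partitions (PARTITION_SIZE = 512 tokens), runs the existing paged attention kernel independently per partition while recording each partition's max logit and exp-sum, and then combines the partial results in a separate reduce kernel (see csrc/attention/attention_kernels.cu below). The following is a minimal PyTorch sketch of that combine step for a single query and head; it mirrors the math of paged_attention_v2_reduce_kernel but is illustrative only and not part of the commit.

    import torch

    def combine_partitions(tmp_out: torch.Tensor,     # [num_partitions, head_size]
                           max_logits: torch.Tensor,  # [num_partitions]
                           exp_sums: torch.Tensor     # [num_partitions]
                           ) -> torch.Tensor:
        # Global max over all partitions keeps the softmax numerically stable.
        global_max = max_logits.max()
        # Rescale each partition's exp-sum to the common max.
        rescaled = exp_sums * torch.exp(max_logits - global_max)
        global_exp_sum = rescaled.sum()
        # Each partition's partial output is already normalized by its own
        # exp-sum inside the kernel, so re-weight by the rescaled exp-sums.
        weights = rescaled / (global_exp_sum + 1e-6)
        return (tmp_out * weights[:, None]).sum(dim=0)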
benchmarks/kernels/benchmark_paged_attention.py (new file, 0 → 100644)
import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Stop the profiler range opened before the timed loop.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size", type=int,
                        choices=[64, 80, 96, 112, 128, 256], default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype", type=str,
                        choices=["half", "bfloat16", "float"], default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
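With the defaults above (context_len = 4096, PARTITION_SIZE = 512, batch size 8, 64 query heads, head size 128), the V2 path uses 8 partitions per sequence, so the kernel grid is (num_heads, num_seqs, max_num_partitions) = (64, 8, 8). The script can be run as, for example, python benchmarks/kernels/benchmark_paged_attention.py --version v2 --context-len 4096. A quick check of that arithmetic, mirroring the benchmark's own setup (illustrative only):

    context_len, partition_size, block_size = 4096, 512, 16
    num_seqs, num_query_heads, head_size = 8, 64, 128

    num_partitions = (context_len + partition_size - 1) // partition_size   # 8
    blocks_per_partition = partition_size // block_size                     # 32
    tmp_output_shape = (num_seqs, num_query_heads, num_partitions, head_size)
    print(num_partitions, blocks_per_partition, tmp_output_shape)
    # -> 8 32 (8, 64, 8, 128)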
csrc/attention.cpp (+24, −4)
 #include <torch/extension.h>
 #include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes);

+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
 }
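On the Python side this renames the extension op: callers that previously used attention_ops.single_query_cached_kv_attention now call attention_ops.paged_attention_v1 (same argument list), while attention_ops.paged_attention_v2 additionally takes the exp_sums, max_logits, and tmp_out workspaces right after the output tensor. A minimal sketch of the two call shapes, reusing the tensors allocated in the benchmark script above (not a standalone program):

    from vllm import attention_ops

    # V1: same arguments as the old single_query_cached_kv_attention op.
    attention_ops.paged_attention_v1(
        output, query, key_cache, value_cache, head_mapping, scale,
        block_tables, context_lens, block_size, max_context_len, alibi_slopes)

    # V2: adds the per-partition workspaces before the query.
    attention_ops.paged_attention_v2(
        output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache,
        head_mapping, scale, block_tables, context_lens, block_size,
        max_context_len, alibi_slopes)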
csrc/attention/attention_kernels.cu (+413, −71)
@@ -26,6 +26,7 @@
 #define WARP_SIZE 32
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

 namespace vllm {
@@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   return __shfl_sync(uint32_t(-1), sum, 0);
 }

-// Grid: (num_heads, num_seqs).
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
-  int NUM_THREADS>
-__global__ void single_query_cached_kv_attention_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  int NUM_THREADS,
+  int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
@@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition
+      = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
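The block/token range arithmetic above is the heart of the partitioning scheme: grid dimension z selects a partition, and the partition is converted into a half-open range of KV-cache blocks and tokens. A small Python mirror of that index math (illustrative, not part of the commit) makes the mapping concrete:

    def partition_ranges(context_len: int, partition_size: int, block_size: int,
                         partition_idx: int):
        """Mirror of the CUDA index math: returns (start_block, end_block,
        start_token, end_token) handled by one thread block."""
        num_context_blocks = (context_len + block_size - 1) // block_size
        blocks_per_partition = partition_size // block_size
        start_block = partition_idx * blocks_per_partition
        end_block = min(start_block + blocks_per_partition, num_context_blocks)
        start_token = start_block * block_size
        end_token = min(start_token + (end_block - start_block) * block_size,
                        context_len)
        return start_block, end_block, start_token, end_token

    # Example: a 1000-token context, 512-token partitions, 16-token blocks.
    # Partition 0 covers blocks [0, 32) -> tokens [0, 512);
    # partition 1 covers blocks [32, 63) -> tokens [512, 1000).
    print(partition_ranges(1000, 512, 16, 1))   # (32, 63, 512, 1000)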
@@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
   const int kv_head_idx = head_mapping[head_idx];
-  const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

   // A vector type to store a part of a key or a query.
@@ -142,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel(
   constexpr int x = 16 / sizeof(scalar_t);
   float qk_max = -FLT_MAX;

-  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  const int context_len = context_lens[seq_idx];
-  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
   // Iterate over the key blocks.
   // Each warp fetches a block of keys for each iteration.
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
-  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
     const int physical_block_number = block_table[block_idx];

     // Load a key to registers.
@@ -184,7 +208,7 @@ __global__ void single_query_cached_kv_attention_kernel(
       // Store the partial reductions to shared memory.
       // NOTE(woosuk): It is required to zero out the masked logits.
       const bool mask = token_idx >= context_len;
-      logits[token_idx] = mask ? 0.f : qk;
+      logits[token_idx - start_token_idx] = mask ? 0.f : qk;
       // Update the max value.
       qk_max = mask ? qk_max : fmaxf(qk_max, qk);
     }
@@ -215,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Get the sum of the exp values.
   float exp_sum = 0.f;
-  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
     float val = __expf(logits[i] - qk_max);
     logits[i] = val;
     exp_sum += val;
@@ -224,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Compute softmax.
   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
-  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
     logits[i] *= inv_sum;
   }
   __syncthreads();

+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions
+                                       + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                   + head_idx * max_num_partitions
+                                   + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
   // Each thread will fetch 16 bytes from the value cache at a time.
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
@@ -237,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

   // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -248,12 +284,12 @@ __global__ void single_query_cached_kv_attention_kernel(
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
     const int physical_block_number = block_table[block_idx];
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

     const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride;
@@ -263,7 +299,7 @@ __global__ void single_query_cached_kv_attention_kernel(
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
         V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
-        if (block_idx == num_blocks - 1) {
+        if (block_idx == num_context_blocks - 1) {
           // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
           // we should explicitly zero out the values since they may contain NaNs.
           // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
@@ -327,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -338,13 +376,167 @@ __global__ void single_query_cached_kv_attention_kernel(
   }
 }

+// Grid: (num_heads, num_seqs, 1).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS>
+__global__ void paged_attention_v1_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
+    /* exp_sums */ nullptr, /* max_logits */ nullptr,
+    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
 } // namespace vllm
-#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                    \
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                              \
   cudaFuncSetAttribute(                                                                   \
-    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
+    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,               \
     cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                        \
-  vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>    \
+  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                  \
   <<<grid, block, shared_mem_size, stream>>>(                                             \
     out_ptr,                                                                              \
     query_ptr,                                                                            \
@@ -365,7 +557,7 @@ template<
   typename T,
   int BLOCK_SIZE,
   int NUM_THREADS = 128>
-void single_query_cached_kv_attention_launcher(
+void paged_attention_v1_launcher(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -401,45 +593,206 @@ void single_query_cached_kv_attention_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

-  dim3 grid(num_heads, num_seqs);
+  dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
     case 64:
       LAUNCH_PAGED_ATTENTION_V1(64);
       break;
     case 80:
       LAUNCH_PAGED_ATTENTION_V1(80);
       break;
     case 96:
       LAUNCH_PAGED_ATTENTION_V1(96);
       break;
     case 112:
       LAUNCH_PAGED_ATTENTION_V1(112);
       break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V1(256);
       break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
   }
 }
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
head_mapping, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, 8); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, 16); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, 32); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
context_lens_ptr, \
max_num_partitions);
template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192.
-    // case 32:
-    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
    case 64:
-      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
-      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
-      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
-      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
-      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
-    // case 160:
-    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-    //   break;
-    // case 192:
-    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-    //   break;
    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
@@ -447,9 +800,12 @@ void single_query_cached_kv_attention_launcher(
      break;
  }
}
-#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE)                         \
-  single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>(         \
+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                             \
+  paged_attention_v2_launcher<T, BLOCK_SIZE>(                       \
     out,                                                            \
+    exp_sums,                                                       \
+    max_logits,                                                     \
+    tmp_out,                                                        \
     query,                                                          \
     key_cache,                                                      \
     value_cache,                                                    \
@@ -462,42 +818,27 @@ void single_query_cached_kv_attention_launcher(
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                          \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                              \
   switch (block_size) {                                             \
-    /* case 1: */                                                   \
-    /*   CALL_KERNEL_LAUNCHER(T, 1); */                             \
-    /*   break; */                                                  \
-    /* case 2: */                                                   \
-    /*   CALL_KERNEL_LAUNCHER(T, 2); */                             \
-    /*   break; */                                                  \
-    /* case 4: */                                                   \
-    /*   CALL_KERNEL_LAUNCHER(T, 4); */                             \
-    /*   break; */                                                  \
     case 8:                                                         \
-      CALL_KERNEL_LAUNCHER(T, 8);                                   \
+      CALL_V2_LAUNCHER(T, 8);                                       \
       break;                                                        \
     case 16:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 16);                                  \
+      CALL_V2_LAUNCHER(T, 16);                                      \
       break;                                                        \
     case 32:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 32);                                  \
+      CALL_V2_LAUNCHER(T, 32);                                      \
       break;                                                        \
-    /* case 64: */                                                  \
-    /*   CALL_KERNEL_LAUNCHER(T, 64); */                            \
-    /*   break; */                                                  \
-    /* case 128: */                                                 \
-    /*   CALL_KERNEL_LAUNCHER(T, 128); */                           \
-    /*   break; */                                                  \
-    /* case 256: */                                                 \
-    /*   CALL_KERNEL_LAUNCHER(T, 256); */                           \
-    /*   break; */                                                  \
     default:                                                        \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
       break;                                                        \
   }
-void single_query_cached_kv_attention(
+void paged_attention_v2(
   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
@@ -509,11 +850,11 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
...
...
@@ -522,3 +863,4 @@ void single_query_cached_kv_attention(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
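For a concrete sense of the launch configuration chosen by paged_attention_v2_launcher above: with the default PARTITION_SIZE of 512, NUM_THREADS = 128, and, say, max_context_len = 4096 and head_size = 128, the sizes work out as below (a worked example of the code above, not additional logic). Note the contrast with the V1 launcher, whose logits buffer scales with padded_max_context_len rather than a fixed PARTITION_SIZE.

    WARP_SIZE, NUM_THREADS, PARTITION_SIZE = 32, 128, 512
    max_context_len, head_size = 4096, 128

    num_warps = NUM_THREADS // WARP_SIZE                          # 4
    max_num_partitions = -(-max_context_len // PARTITION_SIZE)    # 8 (ceil division)
    logits_size = PARTITION_SIZE * 4                              # 2048 bytes of fp32 logits
    outputs_size = (num_warps // 2) * head_size * 4               # 1024 bytes
    shared_mem_size = max(logits_size, outputs_size)              # 2048 bytes per block
    reduce_shared_mem_size = 2 * max_num_partitions * 4           # 64 bytes for the reduce kernel
    print(max_num_partitions, shared_mem_size, reduce_shared_mem_size)
    # -> 8 2048 64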
csrc/attention/dtype_bfloat16.cuh (+5, −0)
...
...
@@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#endif
}
+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
 // Zero-out a variable.
 inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
...
tests/kernels/test_attention.py (+54, −17)
...
...
@@ -14,13 +14,14 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
+PARTITION_SIZE = 512
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
-NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-BLOCK_SIZES = [8, 16, 32]
+BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
...
...
@@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention(
     output[i].copy_(out, non_blocking=True)


+@pytest.mark.parametrize("version", ["v1", "v2"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_single_query_cached_kv_attention(
+def test_paged_attention(
     kv_cache_factory,
+    version: str,
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
...
...
@@ -162,19 +164,54 @@ def test_single_query_cached_kv_attention(
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    attention_ops.single_query_cached_kv_attention(
-        output,
-        query,
-        key_cache,
-        value_cache,
-        head_mapping,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        alibi_slopes,
-    )
+    if version == "v1":
+        attention_ops.paged_attention_v1(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    elif version == "v2":
+        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+                          PARTITION_SIZE)
+        assert PARTITION_SIZE % block_size == 0
+        num_seqs, num_heads, head_size = output.shape
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    else:
+        assert False, f"Unknown version: {version}"

     # Run the reference implementation.
     ref_output = torch.empty_like(query)
...
...
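The test compares both kernel versions against ref_single_query_cached_kv_attention, whose body is not shown in this diff. For orientation, a naive reference for one query over a block table might look roughly like the sketch below; it is hypothetical, written for this note, and differs from the real helper in tests/kernels/test_attention.py in details such as ALiBi handling and the head_mapping used for grouped KV heads.

    import torch

    def naive_single_query_attention(q, key_cache, value_cache, block_table,
                                     context_len, scale):
        # q: [num_heads, head_size]; caches use the paged layouts shown above.
        keys, values = [], []
        block_size = value_cache.shape[3]
        for token_idx in range(context_len):
            block_number = int(block_table[token_idx // block_size])
            offset = token_idx % block_size
            # Key cache layout: [num_blocks, num_heads, head_size/x, block_size, x]
            k = key_cache[block_number, :, :, offset, :].reshape(q.shape)
            # Value cache layout: [num_blocks, num_heads, head_size, block_size]
            v = value_cache[block_number, :, :, offset]
            keys.append(k)
            values.append(v)
        keys = torch.stack(keys)      # [context_len, num_heads, head_size]
        values = torch.stack(values)  # [context_len, num_heads, head_size]
        logits = scale * torch.einsum("hd,thd->ht", q.float(), keys.float())
        probs = torch.softmax(logits, dim=-1)
        return torch.einsum("ht,thd->hd", probs, values.float()).to(q.dtype)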
vllm/model_executor/layers/attention.py (+71, −47)
...
...
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.rotary_embedding import (
     RotaryEmbedding)

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512


 class PagedAttention(nn.Module):
...
...
@@ -130,6 +132,14 @@ class PagedAttention(nn.Module):
             output.copy_(out.squeeze(0))
         return output

+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
...
...
@@ -137,6 +147,7 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.
...
...
@@ -148,21 +159,65 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            None,  # alibi_slopes
-        )
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )

     def forward(
         self,
...
...
@@ -253,7 +308,7 @@ class PagedAttention(nn.Module):
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
                 query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
-                value_cache,
-                input_metadata)
+                value_cache,
+                input_metadata,
+                self.get_alibi_slopes())
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
...
...
@@ -431,36 +486,5 @@ class PagedAttentionWithALiBi(PagedAttention):
             start += prompt_len
         return output

-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
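The V1/V2 dispatch above is a heuristic rather than a hard rule: V2 only pays off when a sequence spans several partitions and there are not already enough (sequence, head) pairs to fill the GPU. A standalone restatement of the decision, useful for reasoning about which path a given batch will take (illustrative only):

    _PARTITION_SIZE = 512

    def use_paged_attention_v1(num_seqs: int, num_heads: int,
                               max_context_len: int) -> bool:
        """Mirror of the dispatch heuristic in PagedAttention.single_query_cached_kv_attention."""
        max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        # V1 when the context fits in one partition (no reduction needed) or when
        # num_seqs * num_heads already provides enough parallelism.
        return max_num_partitions == 1 or num_seqs * num_heads > 512

    # Examples: a single long sequence prefers V2; a large batch prefers V1.
    print(use_paged_attention_v1(num_seqs=1,  num_heads=32, max_context_len=8192))  # False -> V2
    print(use_paged_attention_v1(num_seqs=64, num_heads=32, max_context_len=8192))  # True  -> V1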