norm / vllm · Commits · 928de468 (unverified)
Authored Oct 16, 2023 by Woosuk Kwon; committed via GitHub on Oct 16, 2023.

Implement PagedAttention V2 (#1348)

Parent: 29678cd2

Showing 6 changed files with 764 additions and 139 deletions (+764, -139).
  benchmarks/kernels/benchmark_paged_attention.py   +197    -0
  csrc/attention.cpp                                  +24    -4
  csrc/attention/attention_kernels.cu                +413   -71
  csrc/attention/dtype_bfloat16.cuh                    +5    -0
  tests/kernels/test_attention.py                     +54   -17
  vllm/model_executor/layers/attention.py             +71   -47
benchmarks/kernels/benchmark_paged_attention.py  (new file, 0 → 100644)

import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Stop the profiler range opened above.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
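The benchmark above is normally driven through its argparse interface. A minimal way to trigger the same measurement programmatically, assuming the compiled `vllm.attention_ops` extension and a CUDA device are available and that the script is importable, is to call `main()` with values mirroring the CLI defaults:

```python
# Hypothetical driver mirroring the CLI defaults of the benchmark above.
# Assumes the script is importable as `benchmark_paged_attention` and that
# the vllm C++ extension plus a CUDA GPU are present.
import torch
from benchmark_paged_attention import main  # assumed import path

main(
    version="v2",        # --version default
    num_seqs=8,          # --batch-size default
    context_len=4096,    # --context-len default
    num_query_heads=64,  # --num-query-heads default
    num_kv_heads=8,      # --num-kv-heads default
    head_size=128,       # --head-size default
    use_alibi=False,     # --use-alibi not set
    block_size=16,       # --block-size default
    dtype=torch.half,    # --dtype default ("half")
    seed=0,              # --seed default
    do_profile=False,    # --profile not set
)
```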
csrc/attention.cpp

 #include <torch/extension.h>
 #include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes);

+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
 }
csrc/attention/attention_kernels.cu

@@ -26,6 +26,7 @@
 #define WARP_SIZE 32
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

 namespace vllm {
@@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   return __shfl_sync(uint32_t(-1), sum, 0);
 }

-// Grid: (num_heads, num_seqs).
-// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
-  int NUM_THREADS>
-__global__ void single_query_cached_kv_attention_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  int NUM_THREADS,
+  int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
@@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
@@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
   const int kv_head_idx = head_mapping[head_idx];
-  const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

   // A vector type to store a part of a key or a query.
@@ -142,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel(
   constexpr int x = 16 / sizeof(scalar_t);
   float qk_max = -FLT_MAX;

-  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  const int context_len = context_lens[seq_idx];
-  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
   // Iterate over the key blocks.
   // Each warp fetches a block of keys for each iteration.
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
-  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
     const int physical_block_number = block_table[block_idx];

     // Load a key to registers.
@@ -184,7 +208,7 @@ __global__ void single_query_cached_kv_attention_kernel(
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
         const bool mask = token_idx >= context_len;
-        logits[token_idx] = mask ? 0.f : qk;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
       }
@@ -215,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Get the sum of the exp values.
   float exp_sum = 0.f;
-  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
     float val = __expf(logits[i] - qk_max);
     logits[i] = val;
     exp_sum += val;
@@ -224,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Compute softmax.
   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
-  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
     logits[i] *= inv_sum;
   }
   __syncthreads();

+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions
+                                       + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                   + head_idx * max_num_partitions
+                                   + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
   // Each thread will fetch 16 bytes from the value cache at a time.
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
@@ -237,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

   // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -248,12 +284,12 @@ __global__ void single_query_cached_kv_attention_kernel(
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
     const int physical_block_number = block_table[block_idx];
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

     const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride;
@@ -263,7 +299,7 @@ __global__ void single_query_cached_kv_attention_kernel(
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
         V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
-        if (block_idx == num_blocks - 1) {
+        if (block_idx == num_context_blocks - 1) {
           // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
           // we should explicitly zero out the values since they may contain NaNs.
           // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
@@ -327,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel(
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -338,13 +376,167 @@ __global__ void single_query_cached_kv_attention_kernel(
     }
   }
 }

+// Grid: (num_heads, num_seqs, 1).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS>
+__global__ void paged_attention_v1_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
+    /* exp_sums */ nullptr, /* max_logits */ nullptr,
+    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    q_stride, kv_block_stride, kv_head_stride);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
 } // namespace vllm

-#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                     \
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                               \
   cudaFuncSetAttribute(                                                                    \
-    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,  \
+    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,                \
     cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                         \
-  vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>     \
+  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                   \
   <<<grid, block, shared_mem_size, stream>>>(                                              \
     out_ptr,                                                                               \
     query_ptr,                                                                             \
@@ -365,7 +557,7 @@ template<
   typename T,
   int BLOCK_SIZE,
   int NUM_THREADS = 128>
-void single_query_cached_kv_attention_launcher(
+void paged_attention_v1_launcher(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -401,45 +593,206 @@ void single_query_cached_kv_attention_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

-  dim3 grid(num_heads, num_seqs);
+  dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192.
-    // case 32:
-    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
     case 64:
-      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(64);
       break;
     case 80:
-      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(80);
       break;
     case 96:
-      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(96);
       break;
     case 112:
-      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(112);
       break;
     case 128:
-      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(128);
       break;
-    // case 160:
-    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-    //   break;
-    // case 192:
-    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-    //   break;
     case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      LAUNCH_PAGED_ATTENTION_V1(256);
       break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
   }
 }
+
+#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)       \
+  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
+    out,                                      \
+    query,                                    \
+    key_cache,                                \
+    value_cache,                              \
+    head_mapping,                             \
+    scale,                                    \
+    block_tables,                             \
+    context_lens,                             \
+    max_context_len,                          \
+    alibi_slopes);
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                            \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V1_LAUNCHER(T, 8);                                     \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V1_LAUNCHER(T, 16);                                    \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V1_LAUNCHER(T, 32);                                    \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
+  }
+
+void paged_attention_v1(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+  }
+}
+
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                               \
+  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>   \
+  <<<grid, block, shared_mem_size, stream>>>(                                              \
+    exp_sums_ptr,                                                                          \
+    max_logits_ptr,                                                                        \
+    tmp_out_ptr,                                                                           \
+    query_ptr,                                                                             \
+    key_cache_ptr,                                                                         \
+    value_cache_ptr,                                                                       \
+    head_mapping_ptr,                                                                      \
+    scale,                                                                                 \
+    block_tables_ptr,                                                                      \
+    context_lens_ptr,                                                                      \
+    max_num_blocks_per_seq,                                                                \
+    alibi_slopes_ptr,                                                                      \
+    q_stride,                                                                              \
+    kv_block_stride,                                                                       \
+    kv_head_stride);                                                                       \
+  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>        \
+  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                \
+    out_ptr,                                                                               \
+    exp_sums_ptr,                                                                          \
+    max_logits_ptr,                                                                        \
+    tmp_out_ptr,                                                                           \
+    context_lens_ptr,                                                                      \
+    max_num_partitions);
+
+template<
+  typename T,
+  int BLOCK_SIZE,
+  int NUM_THREADS = 128,
+  int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  int logits_size = PARTITION_SIZE * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+  // For paged attention v2 kernel.
+  dim3 grid(num_heads, num_seqs, max_num_partitions);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+  // For paged attention v2 reduce kernel.
+  dim3 reduce_grid(num_heads, num_seqs);
+  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+  dim3 block(NUM_THREADS);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V2(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V2(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V2(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V2(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V2(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V2(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
@@ -447,9 +800,12 @@ void single_query_cached_kv_attention_launcher(
   }
 }

-#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE)                 \
-  single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)       \
+  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
     out,                                      \
+    exp_sums,                                 \
+    max_logits,                               \
+    tmp_out,                                  \
     query,                                    \
     key_cache,                                \
     value_cache,                              \
@@ -462,42 +818,27 @@ void single_query_cached_kv_attention_launcher(
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                        \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                            \
   switch (block_size) {                                           \
-    /* case 1: */                                                 \
-    /*   CALL_KERNEL_LAUNCHER(T, 1); */                           \
-    /*   break; */                                                \
-    /* case 2: */                                                 \
-    /*   CALL_KERNEL_LAUNCHER(T, 2); */                           \
-    /*   break; */                                                \
-    /* case 4: */                                                 \
-    /*   CALL_KERNEL_LAUNCHER(T, 4); */                           \
-    /*   break; */                                                \
     case 8:                                                       \
-      CALL_KERNEL_LAUNCHER(T, 8);                                 \
+      CALL_V2_LAUNCHER(T, 8);                                     \
       break;                                                      \
     case 16:                                                      \
-      CALL_KERNEL_LAUNCHER(T, 16);                                \
+      CALL_V2_LAUNCHER(T, 16);                                    \
       break;                                                      \
     case 32:                                                      \
-      CALL_KERNEL_LAUNCHER(T, 32);                                \
+      CALL_V2_LAUNCHER(T, 32);                                    \
      break;                                                       \
-    /* case 64: */                                                \
-    /*   CALL_KERNEL_LAUNCHER(T, 64); */                          \
-    /*   break; */                                                \
-    /* case 128: */                                               \
-    /*   CALL_KERNEL_LAUNCHER(T, 128); */                         \
-    /*   break; */                                                \
-    /* case 256: */                                               \
-    /*   CALL_KERNEL_LAUNCHER(T, 256); */                         \
-    /*   break; */                                                \
     default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);  \
      break;                                                       \
   }

-void single_query_cached_kv_attention(
+void paged_attention_v2(
   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
@@ -509,11 +850,11 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
-    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
 }
@@ -522,3 +863,4 @@ void single_query_cached_kv_attention(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
+#undef DIVIDE_ROUND_UP
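The new `paged_attention_v2_reduce_kernel` above merges per-partition softmax results with a log-sum-exp rescaling: each partition's exp-sum is rescaled to the global maximum logit, and the per-partition partial outputs are re-weighted by those rescaled sums. A NumPy sketch of the same arithmetic for a single (sequence, head) pair, purely illustrative and not part of the commit:

```python
# Illustrative NumPy version of the partition reduction done by
# paged_attention_v2_reduce_kernel for one (seq, head) pair.
import numpy as np

def reduce_partitions(tmp_out, exp_sums, max_logits):
    # tmp_out:    [num_partitions, head_size] partial outputs, each already
    #             normalized by its own partition's exp_sum
    # exp_sums:   [num_partitions] sum of exp(logit - partition_max) per partition
    # max_logits: [num_partitions] max logit seen in each partition
    global_max = max_logits.max()
    # Rescale every partition's exp-sum to the global max (log-sum-exp trick).
    rescaled = exp_sums * np.exp(max_logits - global_max)
    global_exp_sum = rescaled.sum()
    # Re-weight each partition's partial output and accumulate.
    weights = rescaled / (global_exp_sum + 1e-6)
    return (tmp_out * weights[:, None]).sum(axis=0)
```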
csrc/attention/dtype_bfloat16.cuh

@@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }

+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
 // Zero-out a variable.
 inline __device__ void zero(__nv_bfloat16& dst) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
tests/kernels/test_attention.py

@@ -14,13 +14,14 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
+PARTITION_SIZE = 512

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
-NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-BLOCK_SIZES = [8, 16, 32]
+BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
@@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


+@pytest.mark.parametrize("version", ["v1", "v2"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_single_query_cached_kv_attention(
+def test_paged_attention(
     kv_cache_factory,
+    version: str,
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
@@ -162,19 +164,54 @@ def test_single_query_cached_kv_attention(
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    attention_ops.single_query_cached_kv_attention(
-        output,
-        query,
-        key_cache,
-        value_cache,
-        head_mapping,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        alibi_slopes,
-    )
+    if version == "v1":
+        attention_ops.paged_attention_v1(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    elif version == "v2":
+        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+                          PARTITION_SIZE)
+        assert PARTITION_SIZE % block_size == 0
+        num_seqs, num_heads, head_size = output.shape
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    else:
+        assert False, f"Unknown version: {version}"

     # Run the reference implementation.
     ref_output = torch.empty_like(query)
vllm/model_executor/layers/attention.py

@@ -15,6 +15,8 @@ from vllm.model_executor.layers.rotary_embedding import (
     RotaryEmbedding)

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512


 class PagedAttention(nn.Module):
@@ -130,6 +132,14 @@ class PagedAttention(nn.Module):
             output.copy_(out.squeeze(0))
         return output

+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
@@ -137,6 +147,7 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.
@@ -148,21 +159,65 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            None,  # alibi_slopes
-        )
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )

     def forward(
         self,
@@ -253,7 +308,7 @@ class PagedAttention(nn.Module):
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
                 query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata)
+                value_cache, input_metadata, self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
@@ -431,36 +486,5 @@ class PagedAttentionWithALiBi(PagedAttention):
             start += prompt_len
         return output

-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
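The heuristic in the new `single_query_cached_kv_attention` body selects V1 whenever there is only one partition or when `num_seqs * num_heads > 512` already provides enough parallel work. As a worked example with `_PARTITION_SIZE = 512`: a batch of 8 sequences with 64 query heads and `max_context_len = 4096` gives `max_num_partitions = 8` and `num_seqs * num_heads = 512`, so the V2 path runs; with 32 sequences the product is 2048 and V1 is used instead. A small standalone restatement of that check (hypothetical helper, not part of the commit):

```python
# Hypothetical restatement of the V1/V2 dispatch heuristic above.
_PARTITION_SIZE = 512

def use_paged_attention_v1(num_seqs: int, num_heads: int,
                           max_context_len: int) -> bool:
    max_num_partitions = (
        (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE)
    # V1 when the reduction would be pointless (one partition) or when the
    # batch already saturates the GPU with work.
    return max_num_partitions == 1 or num_seqs * num_heads > 512

assert use_paged_attention_v1(32, 64, 4096) is True   # 32 * 64 = 2048 > 512
assert use_paged_attention_v1(8, 64, 4096) is False   # 8 partitions, 512 not > 512
```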