vllm · Commit 21b3671b (Unverified)
Authored Apr 04, 2023 by Siyuan (Ryans) Zhuang; committed by GitHub on Apr 04, 2023
Basic attention kernel that supports cached KV + (multi-)prompts (#24)
Parent: 897cb2ae

Showing 3 changed files with 622 additions and 0 deletions (+622, -0)
csrc/attention.cpp           +16   -0
csrc/attention_kernels.cu    +463  -0
tests/kernels/attention.py   +143  -0
csrc/attention.cpp
...
@@ -11,9 +11,25 @@ void single_query_cached_kv_attention(
  int block_size,
  int max_context_len);

void multi_query_cached_kv_attention(
  torch::Tensor& cu_query_lens,
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "single_query_cached_kv_attention",
    &single_query_cached_kv_attention,
    "Compute the attention between an input query and the cached key/value tensors");
  m.def(
    "multi_query_cached_kv_attention",
    &multi_query_cached_kv_attention,
    "Compute the attention between multiple input queries and the cached key/value tensors");
}
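These two bindings are the whole Python surface of the extension. As a rough sketch of how the new multi-query op is driven from Python (the module name `attention_ops` and the tensor layouts are taken from the test file in this commit; the shapes and values here are illustrative only, not a prescribed setup):

import torch
import attention_ops  # the extension built from csrc/attention.cpp

num_heads, head_size, block_size, num_blocks = 3, 64, 8, 128
x = 16 // torch.tensor([], dtype=torch.half).element_size()

# Two prompts with 4 and 3 query tokens, packed into one batch of 7 tokens.
cu_query_lens = torch.tensor([0, 4, 7], dtype=torch.int, device='cuda')
query = torch.randn(7, num_heads, head_size, dtype=torch.half, device='cuda')
out = torch.empty_like(query)
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.half, device='cuda')
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.half, device='cuda')
context_lens = torch.tensor([10, 12], dtype=torch.int, device='cuda')
max_context_len = 12
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = torch.randint(0, num_blocks, (2, max_num_blocks_per_seq),
                             dtype=torch.int, device='cuda')
scale = head_size ** -0.5

attention_ops.multi_query_cached_kv_attention(
    cu_query_lens, out, query, key_cache, value_cache, scale,
    block_tables, context_lens, block_size, max_context_len)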
csrc/attention_kernels.cu
...
@@ -254,6 +254,287 @@ __global__ void single_query_cached_kv_attention_kernel(
  }
}
// Grid: (num_heads, num_query_tokens).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const int seq_start_idx,
  const int seq_len,
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
  const float scale,
  const int* __restrict__ block_table,    // [num_seqs, max_num_blocks_per_seq]
  const int context_len,
  const int max_num_blocks_per_seq) {
  constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int seq_idx = blockIdx.y;

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
  // th vectors of the query, and so on.
  const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 logits and accumulation.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
  float qk_max = -FLT_MAX;

  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
  const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
    // vectors of the key, and so on.
    K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
    for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
                                      + physical_block_offset * x;
      const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
      const int offset1 = (vec_idx * VEC_SIZE) / x;
      const int offset2 = (vec_idx * VEC_SIZE) % x;
      k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
    }

    // Compute dot product.
    // This includes a reduction across the threads in the same thread group.
    const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
    const bool mask = token_idx >= mask_boundary;

    if (thread_group_offset == 0) {
      // Store the partial reductions to shared memory.
      // NOTE(woosuk): It is required to zero out the masked logits.
      logits[token_idx] = mask ? 0.f : qk;
      // Update the max value.
      qk_max = mask ? qk_max : fmaxf(qk_max, qk);
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename FloatVec<V_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;

  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);

    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        accs[i] += dot(logits_vec, cast_to_float(v_vec));
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        convert_from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}
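All of the causal masking in this kernel is carried by the `mask_boundary` expression: the query token at sequence position `seq_idx` may attend only to cache positions strictly below `context_len - seq_len + 1 + (seq_idx - seq_start_idx)`. A minimal Python sketch of that arithmetic (illustrative only, not part of the kernel):

def visible_positions(seq_idx, seq_start_idx, seq_len, context_len):
    # Mirrors the kernel: token_idx >= mask_boundary is masked out,
    # so positions below mask_boundary are the ones the query can see.
    mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx)
    return list(range(mask_boundary))

# A prompt of 3 query tokens (batch positions 5..7) over a context of length 6,
# i.e. 3 prefix tokens already in the cache plus the 3 query tokens themselves:
for s in range(5, 8):
    print(s, visible_positions(s, seq_start_idx=5, seq_len=3, context_len=6))
# 5 [0, 1, 2, 3]          -- first query token sees the prefix plus itself
# 6 [0, 1, 2, 3, 4]
# 7 [0, 1, 2, 3, 4, 5]    -- last query token sees the whole context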
// Grid: (num_heads, num_query_tokens).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__global__ void multi_query_cached_kv_attention_kernel(
  const int* cu_query_lens,               // [num_prompts+1]
  const int* seq_prompt_mapping,          // [num_seqs] mapping from seq_idx to prompt_idx
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
  const float scale,
  const int* __restrict__ block_tables,   // [num_prompts, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_prompts]
  const int max_num_blocks_per_seq) {
  const int seq_idx = blockIdx.y;
  const int prompt_idx = seq_prompt_mapping[seq_idx];
  const int seq_start_idx = cu_query_lens[prompt_idx];
  const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
  const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
  const int context_len = context_lens[prompt_idx];
  multi_query_cached_kv_attention_kernel_unoptimized_<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
    out,
    q,
    seq_start_idx,
    seq_len,
    k_cache,
    v_cache,
    scale,
    block_table,
    context_len,
    max_num_blocks_per_seq);
}
}  // namespace cacheflow
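The wrapper kernel above only does index bookkeeping before delegating to the unoptimized kernel: each thread block looks up its prompt through `seq_prompt_mapping`, then recovers its token range from the cumulative query lengths. The same indexing in a few lines of Python (a sketch with made-up values):

def prompt_slice(cu_query_lens, seq_prompt_mapping, seq_idx):
    # Mirrors the first four statements of the wrapper kernel.
    prompt_idx = seq_prompt_mapping[seq_idx]
    seq_start_idx = cu_query_lens[prompt_idx]
    seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx
    return prompt_idx, seq_start_idx, seq_len

# cu_query_lens = [0, 4, 7]: prompt 0 owns tokens 0..3, prompt 1 owns tokens 4..6.
print(prompt_slice([0, 4, 7], [0, 0, 0, 0, 1, 1, 1], seq_idx=5))  # (1, 4, 3)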
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
...
@@ -402,4 +683,186 @@ void single_query_cached_kv_attention(
  }
}
#define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
  cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
  <<<grid, block, shared_mem_size, stream>>>( \
    cu_query_lens_ptr, \
    seq_prompt_mapping_ptr, \
    out_ptr, \
    query_ptr, \
    key_cache_ptr, \
    value_cache_ptr, \
    scale, \
    block_tables_ptr, \
    context_lens_ptr, \
    max_num_blocks_per_seq);
// TODO(woosuk): Tune NUM_THREADS.
template<typename T, int BLOCK_SIZE, int NUM_THREADS = 128>
void multi_query_cached_kv_attention_launcher(
  torch::Tensor& cu_query_lens,
  torch::Tensor& seq_prompt_mapping,
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);

  int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
  int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    case 32:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
      break;
    case 64:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
      break;
    case 80:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
      break;
    case 96:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
      break;
    case 128:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
      break;
    case 160:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
      break;
    case 192:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
      break;
    case 256:
      LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
      break;
    default:
      assert(false);
      break;
  }
}
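The dynamic shared memory request is the larger of the two uses of `shared_mem` inside the kernel: FP32 logits over the block-padded context, and the staging buffer for the cross-warp output reduction. A quick Python check of the same arithmetic (assuming the default `NUM_THREADS = 128` and `WARP_SIZE = 32`):

WARP_SIZE = 32

def shared_mem_bytes(max_context_len, head_size, block_size, num_threads=128):
    # Same formulas as the launcher above; sizeof(float) == 4.
    num_warps = num_threads // WARP_SIZE
    padded_max_context_len = ((max_context_len + block_size - 1) // block_size) * block_size
    logits_size = padded_max_context_len * 4          # FP32 logits
    outputs_size = (num_warps // 2) * head_size * 4   # cross-warp output staging
    return max(logits_size, outputs_size)

print(shared_mem_bytes(max_context_len=1000, head_size=128, block_size=16))
# 4032: the padded context (1008 floats) dominates the 1024-byte output staging.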
void multi_query_cached_kv_attention(
  torch::Tensor& cu_query_lens,
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len) {
  torch::Tensor query_lens = cu_query_lens.to(torch::kCPU);
  int num_queries = query_lens.size(0) - 1;
  const int* query_lens_ptr = query_lens.data_ptr<int>();
  int num_seqs = query.size(0);
  torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32));
  auto accessor = cpu_tensor.accessor<int32_t, 1>();
  for (int i = 0, query_cursor = 0; i < num_seqs; ++i) {
    if (i >= query_lens_ptr[query_cursor + 1]) {
      ++query_cursor;
    }
    accessor[i] = query_cursor;
  }
  // TODO(suquark): This can be slow, because to(torch::kCPU) and to(torch::kCUDA)
  // implicitly synchronize the CPU and GPU. We can avoid this issue by taking the
  // mapping as an input parameter. Let's do this optimization in a later PR.
  torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA);

  // TODO(woosuk): Support BF16.
  if (query.element_size() == 2) {
    // Half.
    if (block_size == 8) {
      multi_query_cached_kv_attention_launcher<uint16_t, 8>(
        cu_query_lens,
        seq_prompt_mapping,
        out,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        max_context_len);
    } else if (block_size == 16) {
      multi_query_cached_kv_attention_launcher<uint16_t, 16>(
        cu_query_lens,
        seq_prompt_mapping,
        out,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        max_context_len);
    } else {
      assert(false);
    }
  } else if (query.element_size() == 4) {
    // Float.
    if (block_size == 8) {
      multi_query_cached_kv_attention_launcher<float, 8>(
        cu_query_lens,
        seq_prompt_mapping,
        out,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        max_context_len);
    } else if (block_size == 16) {
      multi_query_cached_kv_attention_launcher<float, 16>(
        cu_query_lens,
        seq_prompt_mapping,
        out,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        max_context_len);
    } else {
      assert(false);
    }
  } else {
    assert(false);
  }
}
#undef WARP_SIZE
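For reference, the host-side loop in `multi_query_cached_kv_attention` simply expands the cumulative query lengths into a per-token prompt id before shipping it to the GPU. A pure-Python equivalent of that loop (a sketch, not the extension code):

def build_seq_prompt_mapping(cu_query_lens):
    # cu_query_lens[p] .. cu_query_lens[p + 1] - 1 are the tokens of prompt p.
    mapping, cursor = [], 0
    num_seqs = cu_query_lens[-1]
    for i in range(num_seqs):
        if i >= cu_query_lens[cursor + 1]:
            cursor += 1
        mapping.append(cursor)
    return mapping

print(build_seq_prompt_mapping([0, 4, 7]))  # [0, 0, 0, 0, 1, 1, 1]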
tests/kernels/attention.py
...
@@ -97,6 +97,61 @@ def ref_multi_query_kv_attention(
    return ref_output
def ref_multi_query_cached_kv_attention(
    cu_query_lens: List[int],
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    scale = 1.0 / (head_size ** 0.5)

    num_queries = len(cu_query_lens) - 1
    ref_outputs = []
    for i in range(num_queries):
        start_idx = cu_query_lens[i]
        end_idx = cu_query_lens[i + 1]
        query_len = end_idx - start_idx
        context_len = int(context_lens[i])
        block_table = block_tables[i]

        # Create attention mask
        attn_mask = torch.triu(
            torch.ones(query_len, context_len),
            diagonal=context_len - query_len + 1) * -1e5
        attn_mask = attn_mask.to(dtype=dtype, device='cuda')

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size
            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_heads, head_size)
            keys.append(k)
            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            keys,
            values,
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output
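The inner loop of this reference makes the paged-KV lookup explicit: a logical token position `j` splits into a block-table entry `j // block_size` and an in-block offset `j % block_size`. As a tiny standalone illustration (the block-table values here are hypothetical):

def lookup(block_table, j, block_size):
    # Logical position j lives at offset j % block_size inside the
    # physical block block_table[j // block_size].
    return int(block_table[j // block_size]), j % block_size

# With block_size=8, token 19 falls in block-table entry 2 (physical block 3),
# at offset 3 within that block.
print(lookup([7, 42, 3, 11], 19, block_size=8))  # (3, 3)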
def test_single_query_cached_kv_attention(
    num_tokens: int,
    num_heads: int,
...
@@ -216,6 +271,76 @@ def test_multi_query_kv_attention(
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def test_multi_query_cached_kv_attention(
    num_queries: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]

    query = torch.randn(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_lens = [
        query_len + random.randint(0, MAX_SEQ_LEN - query_len)
        for query_len in query_lens
    ]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')

    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty_like(query)
    attention_ops.multi_query_cached_kv_attention(
        cu_query_lens,
        output,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
    )

    ref_output = ref_multi_query_cached_kv_attention(
        cu_query_lens,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention(seed: int) -> None:
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
...
@@ -237,6 +362,24 @@ def test_attention(seed: int) -> None:
        dtype=dtype,
    )
    # NOTE(siyuan): Same as above. Re-run the test if it fails. Also note
    # that this test is more likely to fail, because the much larger number
    # of tokens in the input increases the variance.
    for dtype in [torch.half, torch.float]:
        for block_size in [8, 16]:
            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                print(f'Testing multi_query_cached_kv_attention with '
                      f'dtype={dtype}, block_size={block_size}, '
                      f'head_size={head_size}')
                test_multi_query_cached_kv_attention(
                    num_queries=11,
                    num_heads=3,
                    head_size=head_size,
                    block_size=block_size,
                    num_blocks=1024,
                    dtype=dtype,
                )
    # NOTE(woosuk): FlashAttention does not support FP32.
    for dtype in [torch.half]:
        # NOTE(woosuk): FlashAttention does not support head_size > 128.
...