"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "5d194c5db0da8eeb591bfcf5e946b79e6efd6853"
Unverified Commit 0f4b3219 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Support various block sizes & Change default block size to 16 (#38)

parent 84eee24e
...@@ -268,6 +268,7 @@ if __name__ == '__main__': ...@@ -268,6 +268,7 @@ if __name__ == '__main__':
f'{model_name}-tp{args.tensor_parallel_size}', f'{model_name}-tp{args.tensor_parallel_size}',
sample_dir, sample_dir,
'cacheflow', 'cacheflow',
f'block{args.block_size}',
f'req-rate-{args.request_rate}', f'req-rate-{args.request_rate}',
f'seed{args.seed}', f'seed{args.seed}',
f'duration-{args.duration}', f'duration-{args.duration}',
......
...@@ -15,9 +15,6 @@ class BlockAllocator: ...@@ -15,9 +15,6 @@ class BlockAllocator:
block_size: int, block_size: int,
num_blocks: int, num_blocks: int,
) -> None: ) -> None:
if block_size not in [8, 16, 32]:
raise ValueError(f'Unsupported block size: {block_size}'
'The block size must be one of {8, 16, 32}.')
self.device = device self.device = device
self.block_size = block_size self.block_size = block_size
self.num_blocks = num_blocks self.num_blocks = num_blocks
......
...@@ -125,7 +125,8 @@ class Scheduler: ...@@ -125,7 +125,8 @@ class Scheduler:
# Swap in the sequence groups in the SWAPPED state if possible. # Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped) self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped: # FCFS
while self.swapped and not blocks_to_swap_out:
seq_group = self.swapped[0] seq_group = self.swapped[0]
# If the sequence group has been preempted in this step, stop. # If the sequence group has been preempted in this step, stop.
if seq_group in preempted: if seq_group in preempted:
......
...@@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser): ...@@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments # KV cache arguments
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size') parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported. # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request). # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
......
...@@ -11,25 +11,9 @@ void single_query_cached_kv_attention( ...@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
int block_size, int block_size,
int max_context_len); int max_context_len);
void multi_query_cached_kv_attention(
torch::Tensor& cu_query_lens,
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"single_query_cached_kv_attention", "single_query_cached_kv_attention",
&single_query_cached_kv_attention, &single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors"); "Compute the attention between an input query and the cached key/value tensors");
m.def(
"multi_query_cached_kv_attention",
&multi_query_cached_kv_attention,
"Compute the attention between multiple input queries and the cached key/value tensors");
} }
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <algorithm> #include <algorithm>
#define WARP_SIZE 32 #define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace cacheflow { namespace cacheflow {
...@@ -27,7 +29,8 @@ __global__ void single_query_cached_kv_attention_kernel( ...@@ -27,7 +29,8 @@ __global__ void single_query_cached_kv_attention_kernel(
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const int q_stride) { const int q_stride) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
...@@ -39,10 +42,10 @@ __global__ void single_query_cached_kv_attention_kernel( ...@@ -39,10 +42,10 @@ __global__ void single_query_cached_kv_attention_kernel(
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread group
// fetch or comput 16 bytes at a time. // fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half, // For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2. // then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
...@@ -88,284 +91,42 @@ __global__ void single_query_cached_kv_attention_kernel( ...@@ -88,284 +91,42 @@ __global__ void single_query_cached_kv_attention_kernel(
// dot product with the query. // dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx]; const int physical_block_number = block_table[block_idx];
const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on. // vectors of the key, and so on.
K_vec k_vecs[NUM_VECS_PER_THREAD]; for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
#pragma unroll const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE K_vec k_vecs[NUM_VECS_PER_THREAD];
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename FloatVec<V_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
accs[i] += dot(logits_vec, cast_to_float(v_vec));
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + head_idx * HEAD_SIZE * BLOCK_SIZE
dst[row_idx] = accs[i]; + physical_block_offset * x;
} const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} }
}
__syncthreads();
// Lower warps update the output. // Compute dot product.
if (warp_idx < mid) { // This includes a reduction across the threads in the same thread group.
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
#pragma unroll const bool mask = token_idx >= context_len;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (thread_group_offset == 0) {
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { // Store the partial reductions to shared memory.
accs[i] += src[row_idx]; // NOTE(woosuk): It is required to zero out the masked logits.
} logits[token_idx] = mask ? 0.f : qk;
} // Update the max value.
} qk_max = mask ? qk_max : fmaxf(qk_max, qk);
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
convert_from_float(*(out_ptr + row_idx), accs[i]);
} }
} }
} }
}
// Grid: (num_heads, num_query_tokens).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int seq_start_idx,
const int seq_len,
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
const int context_len,
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or comput 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 logits and accumulation.
float *logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= mask_boundary;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
...@@ -391,7 +152,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( ...@@ -391,7 +152,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
// Get the sum of the exp values. // Get the sum of the exp values.
float exp_sum = 0.f; float exp_sum = 0.f;
for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max); float val = __expf(logits[i] - qk_max);
logits[i] = val; logits[i] = val;
exp_sum += val; exp_sum += val;
...@@ -406,7 +167,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( ...@@ -406,7 +167,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
__syncthreads(); __syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time. // Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename FloatVec<V_vec>::Type; using L_vec = typename FloatVec<V_vec>::Type;
...@@ -499,46 +260,6 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( ...@@ -499,46 +260,6 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
} }
} }
// Grid: (num_heads, num_query_tokens).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void multi_query_cached_kv_attention_kernel(
const int* cu_query_lens, // [num_prompts+1]
const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_prompts]
const int max_num_blocks_per_seq,
const int q_stride) {
const int seq_idx = blockIdx.y;
const int prompt_idx = seq_prompt_mapping[seq_idx];
const int seq_start_idx = cu_query_lens[prompt_idx];
const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
const int context_len = context_lens[prompt_idx];
multi_query_cached_kv_attention_kernel_unoptimized_<
scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
out,
q,
seq_start_idx,
seq_len,
k_cache,
v_cache,
scale,
block_table,
context_len,
max_num_blocks_per_seq,
q_stride);
}
} // namespace cacheflow } // namespace cacheflow
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
...@@ -574,6 +295,9 @@ void single_query_cached_kv_attention_launcher( ...@@ -574,6 +295,9 @@ void single_query_cached_kv_attention_launcher(
int max_num_blocks_per_seq = block_tables.size(1); int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0); int query_stride = query.stride(0);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
...@@ -621,6 +345,17 @@ void single_query_cached_kv_attention_launcher( ...@@ -621,6 +345,17 @@ void single_query_cached_kv_attention_launcher(
} }
} }
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
scale, \
block_tables, \
context_lens, \
max_context_len);
void single_query_cached_kv_attention( void single_query_cached_kv_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
...@@ -634,285 +369,528 @@ void single_query_cached_kv_attention( ...@@ -634,285 +369,528 @@ void single_query_cached_kv_attention(
// TODO(woosuk): Support BF16. // TODO(woosuk): Support BF16.
if (query.element_size() == 2) { if (query.element_size() == 2) {
// Half. // Half.
if (block_size == 8) { if (block_size == 1) {
single_query_cached_kv_attention_launcher<uint16_t, 8>( CALL_KERNEL_LAUNCHER(uint16_t, 1);
out, } else if (block_size == 2) {
query, CALL_KERNEL_LAUNCHER(uint16_t, 2);
key_cache, } else if (block_size == 4) {
value_cache, CALL_KERNEL_LAUNCHER(uint16_t, 4);
scale, } else if (block_size == 8) {
block_tables, CALL_KERNEL_LAUNCHER(uint16_t, 8);
context_lens,
max_context_len);
} else if (block_size == 16) {
single_query_cached_kv_attention_launcher<uint16_t, 16>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else if (block_size == 32) {
single_query_cached_kv_attention_launcher<uint16_t, 32>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else {
assert(false);
}
} else if (query.element_size() == 4) {
// Float.
if (block_size == 8) {
single_query_cached_kv_attention_launcher<float, 8>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else if (block_size == 16) { } else if (block_size == 16) {
single_query_cached_kv_attention_launcher<float, 16>( CALL_KERNEL_LAUNCHER(uint16_t, 16);
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else if (block_size == 32) { } else if (block_size == 32) {
single_query_cached_kv_attention_launcher<float, 32>( CALL_KERNEL_LAUNCHER(uint16_t, 32);
out, } else if (block_size == 64) {
query, CALL_KERNEL_LAUNCHER(uint16_t, 64);
key_cache, } else if (block_size == 128) {
value_cache, CALL_KERNEL_LAUNCHER(uint16_t, 128);
scale, } else if (block_size == 256) {
block_tables, CALL_KERNEL_LAUNCHER(uint16_t, 256);
context_lens,
max_context_len);
} else { } else {
assert(false); assert(false);
} }
} else { } else {
// Float.
assert(false); assert(false);
} }
} }
// namespace cacheflow {
#define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \ // // Grid: (num_heads, num_query_tokens).
<<<grid, block, shared_mem_size, stream>>>( \ // template<
cu_query_lens_ptr, \ // typename scalar_t,
seq_prompt_mapping_ptr, \ // int HEAD_SIZE,
out_ptr, \ // int BLOCK_SIZE,
query_ptr, \ // int NUM_THREADS>
key_cache_ptr, \ // __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
value_cache_ptr, \ // scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
scale, \ // const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
block_tables_ptr, \ // const int seq_start_idx,
context_lens_ptr, \ // const int seq_len,
max_num_blocks_per_seq, \ // const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
query_stride); // const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
// TODO(woosuk): Tune NUM_THREADS. // const int context_len,
template< // const int max_num_blocks_per_seq,
typename T, // const int q_stride) {
int BLOCK_SIZE, // constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
int NUM_THREADS = 128> // constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
void multi_query_cached_kv_attention_launcher( // const int thread_idx = threadIdx.x;
torch::Tensor& cu_query_lens, // const int warp_idx = thread_idx / WARP_SIZE;
torch::Tensor& seq_prompt_mapping, // const int lane = thread_idx % WARP_SIZE;
torch::Tensor& out,
torch::Tensor& query, // const int head_idx = blockIdx.x;
torch::Tensor& key_cache, // const int num_heads = gridDim.x;
torch::Tensor& value_cache, // const int seq_idx = blockIdx.y;
float scale,
torch::Tensor& block_tables, // // A vector type to store a part of a key or a query.
torch::Tensor& context_lens, // // The vector size is configured in such a way that the threads in a thread group
int max_context_len) { // // fetch or comput 16 bytes at a time.
int num_seqs = query.size(0); // // For example, if the size of a thread group is 4 and the data type is half,
int num_heads = query.size(1); // // then the vector size is 16 / (4 * sizeof(half)) == 2.
int head_size = query.size(2); // constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
int max_num_blocks_per_seq = block_tables.size(1); // using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
int query_stride = query.stride(0); // using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>(); // constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>(); // constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); // const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); // const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); // // Load the query to registers.
int* context_lens_ptr = context_lens.data_ptr<int>(); // // Each thread in a thread group has a different part of the query.
// // For example, if the the thread group size is 4, then the first thread in the group
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; // // th vectors of the query, and so on.
int logits_size = padded_max_context_len * sizeof(float); // // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
int shared_mem_size = std::max(logits_size, outputs_size); // Q_vec q_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
dim3 grid(num_heads, num_seqs); // for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
dim3 block(NUM_THREADS); // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
switch (head_size) { // }
case 32:
LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); // // Memory planning.
break; // extern __shared__ char shared_mem[];
case 64: // // NOTE(woosuk): We use FP32 logits and accumulation.
LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); // float *logits = reinterpret_cast<float*>(shared_mem);
break; // // Workspace for reduction.
case 80: // __shared__ float red_smem[2 * NUM_WARPS];
LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break; // // x == THREAD_GROUP_SIZE * VEC_SIZE
case 96: // // Each thread group fetches x elements from the key at a time.
LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); // constexpr int x = 16 / sizeof(scalar_t);
break; // float qk_max = -FLT_MAX;
case 128:
LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); // const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
break; // const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);
case 160:
LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); // // Iterate over the key blocks.
break; // // Each warp fetches a block of keys for each iteration.
case 192: // // Each thread group in a warp fetches a key from the block, and computes
LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); // // dot product with the query.
break; // for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
case 256: // const int physical_block_number = block_table[block_idx];
LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); // const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
break; // const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
default:
assert(false); // // Load a key to registers.
break; // // Each thread in a thread group has a different part of the key.
} // // For example, if the the thread group size is 4, then the first thread in the group
} // // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// // vectors of the key, and so on.
void multi_query_cached_kv_attention( // K_vec k_vecs[NUM_VECS_PER_THREAD];
torch::Tensor& cu_query_lens, // #pragma unroll
torch::Tensor& out, // for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
torch::Tensor& query, // const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
torch::Tensor& key_cache, // + head_idx * HEAD_SIZE * BLOCK_SIZE
torch::Tensor& value_cache, // + physical_block_offset * x;
float scale, // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
torch::Tensor& block_tables, // const int offset1 = (vec_idx * VEC_SIZE) / x;
torch::Tensor& context_lens, // const int offset2 = (vec_idx * VEC_SIZE) % x;
int block_size, // k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
int max_context_len) { // }
torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); // // Compute dot product.
// // This includes a reduction across the threads in the same thread group.
// const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// const bool mask = token_idx >= mask_boundary;
// if (thread_group_offset == 0) {
// // Store the partial reductions to shared memory.
// // NOTE(woosuk): It is required to zero out the masked logits.
// logits[token_idx] = mask ? 0.f : qk;
// // Update the max value.
// qk_max = mask ? qk_max : fmaxf(qk_max, qk);
// }
// }
// // Perform reduction across the threads in the same warp to get the
// // max qk value for each "warp" (not across the thread block yet).
// // The 0-th thread of each thread group already has its max qk value.
// #pragma unroll
// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// if (lane == 0) {
// red_smem[warp_idx] = qk_max;
// }
// __syncthreads();
// // TODO(woosuk): Refactor this part.
// // Get the max qk value for the sequence.
// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
// #pragma unroll
// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// // Broadcast the max qk value to all threads.
// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// // Get the sum of the exp values.
// float exp_sum = 0.f;
// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) {
// float val = __expf(logits[i] - qk_max);
// logits[i] = val;
// exp_sum += val;
// }
// exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// // Compute softmax.
// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
// for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
// logits[i] *= inv_sum;
// }
// __syncthreads();
// // Each thread will fetch 16 bytes from the value cache at a time.
// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
// using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
// using L_vec = typename FloatVec<V_vec>::Type;
// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
// float accs[NUM_ROWS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// accs[i] = 0.f;
// }
// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
// const int physical_block_number = block_table[block_idx];
// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
// + head_idx * HEAD_SIZE * BLOCK_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE) {
// const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
// V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
// accs[i] += dot(logits_vec, cast_to_float(v_vec));
// }
// }
// }
// // Perform reduction within each warp.
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// float acc = accs[i];
// #pragma unroll
// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
// acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
// }
// accs[i] = acc;
// }
// // NOTE(woosuk): A barrier is required because the shared memory space for logits
// // is reused for the output.
// __syncthreads();
// // Perform reduction across warps.
// float* out_smem = reinterpret_cast<float*>(shared_mem);
// #pragma unroll
// for (int i = NUM_WARPS; i > 1; i /= 2) {
// int mid = i / 2;
// // Upper warps write to shared memory.
// if (warp_idx >= mid && warp_idx < i) {
// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// dst[row_idx] = accs[i];
// }
// }
// }
// __syncthreads();
// // Lower warps update the output.
// if (warp_idx < mid) {
// const float* src = &out_smem[warp_idx * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// accs[i] += src[row_idx];
// }
// }
// }
// __syncthreads();
// }
// // Write the final output.
// if (warp_idx == 0) {
// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// convert_from_float(*(out_ptr + row_idx), accs[i]);
// }
// }
// }
// }
// // Grid: (num_heads, num_query_tokens).
// template<
// typename scalar_t,
// int HEAD_SIZE,
// int BLOCK_SIZE,
// int NUM_THREADS>
// __global__ void multi_query_cached_kv_attention_kernel(
// const int* cu_query_lens, // [num_prompts+1]
// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
// const int* __restrict__ context_lens, // [num_prompts]
// const int max_num_blocks_per_seq,
// const int q_stride) {
// const int seq_idx = blockIdx.y;
// const int prompt_idx = seq_prompt_mapping[seq_idx];
// const int seq_start_idx = cu_query_lens[prompt_idx];
// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
// const int context_len = context_lens[prompt_idx];
// multi_query_cached_kv_attention_kernel_unoptimized_<
// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
// out,
// q,
// seq_start_idx,
// seq_len,
// k_cache,
// v_cache,
// scale,
// block_table,
// context_len,
// max_num_blocks_per_seq,
// q_stride);
// }
// } // namespace cacheflow
// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
// cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
// <<<grid, block, shared_mem_size, stream>>>( \
// cu_query_lens_ptr, \
// seq_prompt_mapping_ptr, \
// out_ptr, \
// query_ptr, \
// key_cache_ptr, \
// value_cache_ptr, \
// scale, \
// block_tables_ptr, \
// context_lens_ptr, \
// max_num_blocks_per_seq, \
// query_stride);
// // TODO(woosuk): Tune NUM_THREADS.
// template<
// typename T,
// int BLOCK_SIZE,
// int NUM_THREADS = 128>
// void multi_query_cached_kv_attention_launcher(
// torch::Tensor& cu_query_lens,
// torch::Tensor& seq_prompt_mapping,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int max_context_len) {
// int num_seqs = query.size(0);
// int num_heads = query.size(1);
// int head_size = query.size(2);
// int max_num_blocks_per_seq = block_tables.size(1);
// int query_stride = query.stride(0);
// int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
// T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
// T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
// T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
// T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
// int* block_tables_ptr = block_tables.data_ptr<int>();
// int* context_lens_ptr = context_lens.data_ptr<int>();
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
// int logits_size = padded_max_context_len * sizeof(float);
// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// int shared_mem_size = std::max(logits_size, outputs_size);
// dim3 grid(num_heads, num_seqs);
// dim3 block(NUM_THREADS);
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// switch (head_size) {
// case 32:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
// case 64:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
// break;
// case 80:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
// break;
// case 96:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
// break;
// case 128:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
// break;
// case 160:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
// break;
// case 192:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
// case 256:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
// break;
// default:
// assert(false);
// break;
// }
// }
// void multi_query_cached_kv_attention(
// torch::Tensor& cu_query_lens,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int block_size,
// int max_context_len) {
// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU);
int num_queries = query_lens.size(0) - 1; // int num_queries = query_lens.size(0) - 1;
const int* query_lens_ptr = query_lens.data_ptr<int>(); // const int* query_lens_ptr = query_lens.data_ptr<int>();
int num_seqs = query.size(0); // int num_seqs = query.size(0);
torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); // torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32));
auto accessor = cpu_tensor.accessor<int32_t, 1>(); // auto accessor = cpu_tensor.accessor<int32_t, 1>();
for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { // for (int i = 0, query_cursor = 0; i < num_seqs; ++i) {
if (i >= query_lens_ptr[query_cursor + 1]) { // if (i >= query_lens_ptr[query_cursor + 1]) {
++query_cursor; // ++query_cursor;
} // }
accessor[i] = query_cursor; // accessor[i] = query_cursor;
} // }
// TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) // // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA)
// implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving // // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving
// the mapping as an input parameter. Let's do this optimization in a later PR. // // the mapping as an input parameter. Let's do this optimization in a later PR.
torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); // torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA);
// TODO(woosuk): Support BF16. // // TODO(woosuk): Support BF16.
if (query.element_size() == 2) { // if (query.element_size() == 2) {
// Half. // // Half.
if (block_size == 8) { // if (block_size == 8) {
multi_query_cached_kv_attention_launcher<uint16_t, 8>( // multi_query_cached_kv_attention_launcher<uint16_t, 8>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else if (block_size == 16) { // } else if (block_size == 16) {
multi_query_cached_kv_attention_launcher<uint16_t, 16>( // multi_query_cached_kv_attention_launcher<uint16_t, 16>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else if (block_size == 32) { // } else if (block_size == 32) {
multi_query_cached_kv_attention_launcher<uint16_t, 32>( // multi_query_cached_kv_attention_launcher<uint16_t, 32>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else { // } else {
assert(false); // assert(false);
} // }
} else if (query.element_size() == 4) { // } else if (query.element_size() == 4) {
// Float. // // Float.
if (block_size == 8) { // if (block_size == 8) {
multi_query_cached_kv_attention_launcher<float, 8>( // multi_query_cached_kv_attention_launcher<float, 8>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else if (block_size == 16) { // } else if (block_size == 16) {
multi_query_cached_kv_attention_launcher<float, 16>( // multi_query_cached_kv_attention_launcher<float, 16>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else if (block_size == 32) { // } else if (block_size == 32) {
multi_query_cached_kv_attention_launcher<float, 32>( // multi_query_cached_kv_attention_launcher<float, 32>(
cu_query_lens, // cu_query_lens,
seq_prompt_mapping, // seq_prompt_mapping,
out, // out,
query, // query,
key_cache, // key_cache,
value_cache, // value_cache,
scale, // scale,
block_tables, // block_tables,
context_lens, // context_lens,
max_context_len); // max_context_len);
} else { // } else {
assert(false); // assert(false);
} // }
} else { // } else {
assert(false); // assert(false);
} // }
} // }
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX
#undef MIN
...@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v) ...@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float dot(float a, float b)
{
return a * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float dot(float2 a, float2 b)
{
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float dot(Float4_ a, Float4_ b) inline __device__ float dot(Float4_ a, Float4_ b)
{ {
float2 acc = mul<float2, float2, float2>(a.x, b.x); float2 acc = mul<float2, float2, float2>(a.x, b.x);
...@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u) ...@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float cast_to_float(float u) // inline __device__ float cast_to_float(float u)
{ // {
return u; // return u;
} // }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 cast_to_float(float2 u) // inline __device__ float2 cast_to_float(float2 u)
{ // {
return u; // return u;
} // }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 cast_to_float(float4 u) // inline __device__ float4 cast_to_float(float4 u)
{ // {
return u; // return u;
} // }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ cast_to_float(Float4_ u) // inline __device__ Float4_ cast_to_float(Float4_ u)
{ // {
return u; // return u;
} // }
////////////////////////////////////////////////////////////////////////////////////////////////////
// inline __device__ Float8_ cast_to_float(Float8_ u)
// {
// return u;
// }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ cast_to_float(Float8_ u) inline __device__ float cast_to_float(uint16_t u)
{ {
return u; return half_to_float(u);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment