Unverified Commit 3521ba4f authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

parent 2d7bce9c
...@@ -16,7 +16,7 @@ PARTITION_SIZE = 512 ...@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main( def main(
version: str, version: str,
num_seqs: int, num_seqs: int,
context_len: int, seq_len: int,
num_query_heads: int, num_query_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
...@@ -48,12 +48,12 @@ def main( ...@@ -48,12 +48,12 @@ def main(
dtype=torch.float, dtype=torch.float,
device=device) device=device)
context_lens = [context_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
...@@ -77,8 +77,7 @@ def main( ...@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
if version == "v2": if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
...@@ -110,9 +109,9 @@ def main( ...@@ -110,9 +109,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -129,9 +128,9 @@ def main( ...@@ -129,9 +128,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -166,7 +165,7 @@ if __name__ == '__main__': ...@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices=["v1", "v2"], choices=["v1", "v2"],
default="v2") default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
...@@ -199,7 +198,7 @@ if __name__ == '__main__': ...@@ -199,7 +198,7 @@ if __name__ == '__main__':
main( main(
version=args.version, version=args.version,
num_seqs=args.batch_size, num_seqs=args.batch_size,
context_len=args.context_len, seq_len=args.seq_len,
num_query_heads=args.num_query_heads, num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads, num_kv_heads=args.num_kv_heads,
head_size=args.head_size, head_size=args.head_size,
......
...@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( ...@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( ...@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
...@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( ...@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len; const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
...@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( ...@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
} else { } else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} }
if (block_idx == num_context_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs. // we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
...@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, /* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
} }
...@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( ...@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( ...@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale); q_stride, kv_block_stride, kv_head_stride, kv_scale);
} }
...@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) { const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
...@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
alibi_slopes_ptr, \ alibi_slopes_ptr, \
q_stride, \ q_stride, \
...@@ -639,8 +639,8 @@ void paged_attention_v1_launcher( ...@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -664,11 +664,11 @@ void paged_attention_v1_launcher( ...@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here! // Keep that in sync with the logic here!
...@@ -715,8 +715,8 @@ void paged_attention_v1_launcher( ...@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables, \ block_tables, \
context_lens, \ seq_lens, \
max_context_len, \ max_seq_len, \
alibi_slopes, \ alibi_slopes, \
kv_scale); kv_scale);
...@@ -746,9 +746,9 @@ void paged_attention_v1( ...@@ -746,9 +746,9 @@ void paged_attention_v1(
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale) { float kv_scale) {
...@@ -790,7 +790,7 @@ void paged_attention_v1( ...@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
alibi_slopes_ptr, \ alibi_slopes_ptr, \
q_stride, \ q_stride, \
...@@ -803,7 +803,7 @@ void paged_attention_v1( ...@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \ exp_sums_ptr, \
max_logits_ptr, \ max_logits_ptr, \
tmp_out_ptr, \ tmp_out_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template< template<
...@@ -824,8 +824,8 @@ void paged_attention_v2_launcher( ...@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -852,10 +852,10 @@ void paged_attention_v2_launcher( ...@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float); int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
...@@ -909,8 +909,8 @@ void paged_attention_v2_launcher( ...@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables, \ block_tables, \
context_lens, \ seq_lens, \
max_context_len, \ max_seq_len, \
alibi_slopes, \ alibi_slopes, \
kv_scale); kv_scale);
...@@ -943,9 +943,9 @@ void paged_attention_v2( ...@@ -943,9 +943,9 @@ void paged_attention_v2(
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale) { float kv_scale) {
......
...@@ -70,11 +70,11 @@ template <typename T> ...@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity, reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope, const int start_index,
const int context_len) { const int seq_len) {
data[0] += alibi_slope * (start_index - context_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk; data[i] = qk;
max = max >= qk ? max : qk; max = max >= qk ? max : qk;
} }
...@@ -225,7 +225,7 @@ struct paged_attention_v1_impl { ...@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
...@@ -235,32 +235,32 @@ struct paged_attention_v1_impl { ...@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert(BLOCK_SIZE == 16); static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads(); const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE; seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits = float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
...@@ -278,11 +278,11 @@ struct paged_attention_v1_impl { ...@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax // Compute softmax
if (alibi_slopes) { if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len, reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, context_len, reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE); block_num * BLOCK_SIZE);
} }
...@@ -340,7 +340,7 @@ struct paged_attention_v1_impl { ...@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
...@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE> ...@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( ...@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
...@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( ...@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
...@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, ...@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size, torch::Tensor &seq_lens, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
...@@ -448,7 +448,7 @@ struct paged_attention_v2_impl { ...@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
...@@ -465,22 +465,22 @@ struct paged_attention_v2_impl { ...@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions; for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) { ++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len) if (start_token_idx >= seq_len)
continue; continue;
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1); const bool no_reduce = (partition_num == 1);
const int context_token_num = const int token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
...@@ -507,10 +507,10 @@ struct paged_attention_v2_impl { ...@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum; std::pair<float, float> max_and_sum;
if (alibi_slopes) { if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi( max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, context_token_num, max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE); block_num * BLOCK_SIZE);
} }
...@@ -583,9 +583,9 @@ struct paged_attention_v2_impl { ...@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1) #pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1)
continue; continue;
...@@ -612,9 +612,9 @@ struct paged_attention_v2_impl { ...@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1)
continue; continue;
...@@ -649,7 +649,7 @@ struct paged_attention_v2_impl { ...@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
...@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( ...@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( ...@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
...@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( ...@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_context_len, alibi_slopes); max_seq_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
...@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, ...@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables, float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size, torch::Tensor &seq_lens, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
......
...@@ -10,9 +10,9 @@ void paged_attention_v1( ...@@ -10,9 +10,9 @@ void paged_attention_v1(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale); float kv_scale);
...@@ -28,9 +28,9 @@ void paged_attention_v2( ...@@ -28,9 +28,9 @@ void paged_attention_v2(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale); float kv_scale);
......
...@@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( ...@@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention(
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
block_tables: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, seq_lens: torch.Tensor,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
) -> None: ) -> None:
...@@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( ...@@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention(
num_seqs = query.shape[0] num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist() block_tables = block_tables.cpu().tolist()
context_lens = context_lens.cpu().tolist() seq_lens = seq_lens.cpu().tolist()
for i in range(num_seqs): for i in range(num_seqs):
q = query[i].unsqueeze(0) q = query[i].unsqueeze(0)
block_table = block_tables[i] block_table = block_tables[i]
context_len = int(context_lens[i]) seq_len = int(seq_lens[i])
keys = [] keys = []
values = [] values = []
for j in range(context_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
...@@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( ...@@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None alibi_bias = None
if alibi_slopes is not None: if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel. # Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len).int() position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - context_len + 1).float() alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1) 1, 1, -1)
...@@ -149,13 +149,13 @@ def test_paged_attention( ...@@ -149,13 +149,13 @@ def test_paged_attention(
if use_alibi: if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN seq_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int) seq_lens = torch.tensor(seq_lens, dtype=torch.int)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
...@@ -186,16 +186,15 @@ def test_paged_attention( ...@@ -186,16 +186,15 @@ def test_paged_attention(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
) )
elif version == "v2": elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0 assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty( tmp_output = torch.empty(
...@@ -218,9 +217,9 @@ def test_paged_attention( ...@@ -218,9 +217,9 @@ def test_paged_attention(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -255,7 +254,7 @@ def test_paged_attention( ...@@ -255,7 +254,7 @@ def test_paged_attention(
key_cache, key_cache,
value_cache, value_cache,
block_tables, block_tables,
context_lens, seq_lens,
scale, scale,
alibi_slopes, alibi_slopes,
) )
......
...@@ -51,12 +51,12 @@ def test_contexted_kv_attention( ...@@ -51,12 +51,12 @@ def test_contexted_kv_attention(
cache_size = 640 cache_size = 640
block_size = 32 block_size = 32
max_block_per_request = 64 max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(subquery_lens) num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3) query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
...@@ -75,15 +75,15 @@ def test_contexted_kv_attention( ...@@ -75,15 +75,15 @@ def test_contexted_kv_attention(
num_kv_heads, num_kv_heads,
head_size, head_size,
dtype=dtype) dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long) values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)] values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view( block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request) BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long), dtype=torch.long),
dim=0) dim=0)
max_input_len = MAX_SEQ_LEN max_input_len = MAX_SEQ_LEN
...@@ -92,7 +92,7 @@ def test_contexted_kv_attention( ...@@ -92,7 +92,7 @@ def test_contexted_kv_attention(
dtype=torch.long), dtype=torch.long),
dim=0) dim=0)
for i in range(BS): for i in range(BS):
for j in range(subquery_lens[i]): for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j]) j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
...@@ -178,7 +178,7 @@ def test_contexted_kv_attention( ...@@ -178,7 +178,7 @@ def test_contexted_kv_attention(
value = value.unsqueeze(0) value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens) query_lens, seq_lens)
if sliding_window > 0: if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright( attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window) sliding_window)
......
...@@ -58,7 +58,7 @@ def _do_sample( ...@@ -58,7 +58,7 @@ def _do_sample(
device: str, device: str,
): ):
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -68,12 +68,12 @@ def _do_sample( ...@@ -68,12 +68,12 @@ def _do_sample(
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=device, device=device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
...@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list" "Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
prompt_lens = [] seq_lens = []
sampling_params_per_row = [] sampling_params_per_row = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
...@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence # a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values())) seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len() prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len) seq_lens.append(prompt_len)
if sgm.sampling_params.prompt_logprobs: if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in # with prompt_logprobs each token in the prompt has a row in
...@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
_, fake_logits, sampler, model_runner = _prepare_test(batch_size) _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None, seq_lens=seq_lens if seq_lens else None,
subquery_lens=prompt_lens if prompt_lens else None, query_lens=seq_lens if seq_lens else None,
device=device, device=device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler # the logits tensor is modified in-place by the sampler
...@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_group_metadata_list = [] seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = [] expected_tokens: List[Optional[List[int]]] = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3) sampling_type = random.randint(0, 3)
...@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
def test_sampling(model_runner: ModelRunner): def test_sampling(model_runner: ModelRunner):
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=device, device=device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits, sampler_output = sampler(logits=fake_logits,
...@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample # Shuffle the batch and resample
target_index = list(range(batch_size)) target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list, for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens): expected_tokens, seq_lens):
random.Random(seed).shuffle(list_to_shuffle) random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index) target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index) input_tensor.data = input_tensor.index_select(0, target_index)
...@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert len(warpers) == 2 # top_p and top_k assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
), ),
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=device, device=device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
......
...@@ -45,7 +45,7 @@ class AsyncLLM: ...@@ -45,7 +45,7 @@ class AsyncLLM:
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -66,7 +66,7 @@ class AsyncLLM: ...@@ -66,7 +66,7 @@ class AsyncLLM:
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space, swap_space=swap_space,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture,
engine_use_ray=True, engine_use_ray=True,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs, **kwargs,
......
...@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int): ...@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
list(range(block_size * 2)), list(range(block_size * 2)),
] ]
final_seq_lens = [ final_prompt_lens = [
len(prompt + output) + num_steps len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens) for prompt, output in zip(prompts, prev_output_tokens)
] ]
...@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int): ...@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts, prompts,
num_gpu_blocks, num_gpu_blocks,
block_size, block_size,
final_seq_lens, final_prompt_lens,
continuations=prev_output_tokens) continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
...@@ -103,17 +103,21 @@ def test_same_output_for_single_step(): ...@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
[6, 7, 8, 9, 10], [6, 7, 8, 9, 10],
] ]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts] final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_execute_model_data = create_execute_model_data( multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, prompts,
final_seq_lens=final_seq_lens)) num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
single_step_execute_model_data = create_execute_model_data( single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, prompts,
final_seq_lens=final_seq_lens)) num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
zero_kv_cache(multi_step_worker.cache_engine) zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
...@@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): ...@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random.randint(0, 1000) for _ in range(random.randint(10, 20)) random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)] ] for _ in range(10)]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts] final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds( multi_step_worker.execute_model = patch_execute_model_with_seeds(
...@@ -195,7 +199,7 @@ def test_same_output_for_multi_step(): ...@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks, num_gpu_blocks,
block_size, block_size,
continuations=continuations, continuations=continuations,
final_seq_lens=final_seq_lens), ) final_prompt_lens=final_prompt_lens), )
# Run multi-step. # Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine) zero_kv_cache(multi_step_worker.cache_engine)
...@@ -217,7 +221,7 @@ def test_same_output_for_multi_step(): ...@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks, num_gpu_blocks,
block_size, block_size,
continuations=continuations, continuations=continuations,
final_seq_lens=final_seq_lens)) final_prompt_lens=final_prompt_lens))
single_step_output.extend( single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), )) worker.execute_model(**execute_model_data.to_dict(), ))
......
...@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match(): ...@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
] ]
proposal_len = 5 proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data( ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, prompts,
final_seq_lens=final_seq_lens)) num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals( proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(), **ngram_sampler_output_data.to_dict(),
...@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ...@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
] ]
proposal_len = 5 proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data( ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, prompts,
final_seq_lens=final_seq_lens)) num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals( proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(), **ngram_sampler_output_data.to_dict(),
...@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all(): ...@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
] ]
proposal_len = 5 proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data( ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, prompts,
final_seq_lens=final_seq_lens)) num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals( proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(), **ngram_sampler_output_data.to_dict(),
......
...@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts( ...@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]], prompts: List[List[int]],
num_gpu_blocks: int, num_gpu_blocks: int,
block_size: int, block_size: int,
final_seq_lens: List[int], final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None, continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None, seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
...@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts( ...@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop() free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size)) for _ in range(round_up_to_next_block(final_len, block_size))
] ]
for i, final_len in enumerate(final_seq_lens) for i, final_len in enumerate(final_prompt_lens)
} }
return [ return [
...@@ -251,13 +251,13 @@ def create_batch(batch_size, ...@@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens = [[ prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len) next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)] ] for _ in range(batch_size)]
final_seq_lens = [ final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1 len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens) for prompt, prev_output_token in zip(prompts, prev_output_tokens)
] ]
execute_model_data = create_execute_model_data( execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens, block_size, final_prompt_lens,
prev_output_tokens, seq_ids), ) prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens return execute_model_data, prompts, prev_output_tokens
...@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str): ...@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return logits return logits
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str): ...@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
logits_processors=[pick_ith]), logits_processors=[pick_ith]),
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=model_runner.device, device=model_runner.device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor( logits_processor_output = logits_processor(
......
...@@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): ...@@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config=None) lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
block_tables = {0: [1]} block_tables = {0: [1]}
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(prompt_len))) seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): ...@@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices = [] expected_selected_token_indices = []
selected_token_start_idx = 0 selected_token_start_idx = 0
for prompt_len in prompt_lens: for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) seq_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_seq_lens == seq_lens
assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens) assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is True assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor, assert torch.allclose(
torch.tensor(prompt_lens, device=device)) attn_metadata.seq_lens_tensor,
assert attn_metadata.prompt_lens == prompt_lens torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_prompt_len == max(prompt_lens) assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
# Test subquery start locs. # Test subquery start locs.
start_idx = 0 start_idx = 0
start_loc = [start_idx] start_loc = [start_idx]
for prompt_len in prompt_lens: for seq_len in seq_lens:
start_idx += prompt_len start_idx += seq_len
start_loc.append(start_idx) start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
attn_metadata.subquery_start_loc, attn_metadata.subquery_start_loc,
...@@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size): ...@@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc. # equivalent to subquery_start_loc.
start_idx = 0 start_idx = 0
seq_start_loc = [start_idx] seq_start_loc = [start_idx]
for prompt_len in prompt_lens: for seq_len in seq_lens:
start_idx += prompt_len start_idx += seq_len
seq_start_loc.append(start_idx) seq_start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
attn_metadata.seq_start_loc, attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
assert attn_metadata.max_context_len is None
assert torch.allclose( assert torch.allclose(
attn_metadata.context_lens, attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens.shape[0], torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int, dtype=torch.int,
device=device)) device=device))
...@@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size): ...@@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prerill. # Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert len(input_tokens) == sum(prompt_lens) assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(prompt_lens) assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=model_runner.device, device=model_runner.device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens) assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(prompt_lens) assert len(input_positions) == sum(seq_lens)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
...@@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config=None) lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = list(range(prompt_len)) seq_data = list(range(seq_len))
seq_data = SequenceData(seq_data) seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
...@@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None assert attn_metadata.seq_lens is None
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens) assert attn_metadata.max_seq_len == max(seq_lens)
assert torch.allclose( assert torch.allclose(
attn_metadata.context_lens[:len(prompt_lens)], attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device)) torch.tensor(seq_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in # block table's first index corresponds to each batch, meaning in
# decoding it is each token. # decoding it is each token.
...@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
selected_token_start_idx = 0 selected_token_start_idx = 0
for prompt_len in prompt_lens: for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx) expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1 selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
subquery_lens=prompt_lens, query_lens=seq_lens,
device=model_runner.device, device=model_runner.device,
pin_memory=model_runner.pin_memory) pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
...@@ -241,14 +239,13 @@ def test_empty_seq_group(): ...@@ -241,14 +239,13 @@ def test_empty_seq_group():
assert attn_metadata is None assert attn_metadata is None
assert len(slot_mapping) == 0 assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0 assert len(input_tokens) == 0
assert len(input_positions) == 0 assert len(input_positions) == 0
assert attn_metadata is None assert attn_metadata is None
assert len(slot_mapping) == 0 assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0 assert len(return_seq_lens) == 0
@pytest.fixture @pytest.fixture
...@@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_runner.set_block_size(16) model_runner.set_block_size(16)
# Add prefill requests. # Add prefill requests.
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
prefill_metadata_list = [] prefill_metadata_list = []
decode_metadata_list = [] decode_metadata_list = []
...@@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_batch_size = batch_size - prefill_batch_size decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size): for i in range(prefill_batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(prompt_len))) seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests # Add decode requests
for i in range(prefill_batch_size, batch_size): for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len)) prompt_toks = list(range(seq_len))
seq_data = SequenceData(prompt_toks) seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
...@@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
else: else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size( assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size) decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens) assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual # Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above. # values here because they are tested above.
......
...@@ -39,17 +39,17 @@ def paged_attention_v1( ...@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, seq_lens: torch.Tensor,
block_size: int, block_size: int,
max_context_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, kv_scale: float,
) -> None: ) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, num_kv_heads, scale, block_tables, seq_lens,
context_lens, block_size, max_context_len, block_size, max_seq_len, alibi_slopes,
alibi_slopes, kv_cache_dtype, kv_scale) kv_cache_dtype, kv_scale)
def paged_attention_v2( def paged_attention_v2(
...@@ -63,17 +63,17 @@ def paged_attention_v2( ...@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, seq_lens: torch.Tensor,
block_size: int, block_size: int,
max_context_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, kv_scale: float,
) -> None: ) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale, key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size, block_tables, seq_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale) kv_scale)
......
...@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, ...@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding. # (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens: Optional[List[int]] # the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor. seq_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor] # seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---| # |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------| # |---------- context_len ----------|
# |-------------------- seqlen ----------------------| # |-------------------- seq_len ----------------------|
# |- subquery_len -| # |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is # Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens. max_query_len: Optional[int]
# When it is for decoding, it includes a new token. # Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10]. # is [4, 6], it is [0, 4, 10].
...@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, ...@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is # the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10]. # [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled. # Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
...@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value, v=value,
cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len, max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_prompt_len, max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
...@@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
prefill_meta.block_tables, prefill_meta.block_tables,
prefill_meta.subquery_start_loc, prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor, prefill_meta.seq_lens_tensor,
prefill_meta.context_lens, prefill_meta.context_lens_tensor,
prefill_meta.max_subquery_len, prefill_meta.max_query_len,
self.alibi_slopes, self.alibi_slopes,
self.sliding_window[0], self.sliding_window[0],
) )
...@@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache, key_cache,
value_cache, value_cache,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.context_lens, decode_meta.seq_lens_tensor,
decode_meta.max_context_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
......
...@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, ...@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding. # (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens: Optional[List[int]] # the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor. seq_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor] # seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---| # |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------| # |---------- context_len ----------|
# |-------------------- seqlen ----------------------| # |-------------------- seq_len ----------------------|
# |- subquery_len -| # |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is # Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens. max_query_len: Optional[int]
# When it is for decoding, it includes a new token. # Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10]. # is [4, 6], it is [0, 4, 10].
...@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, ...@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
class ROCmFlashAttentionImpl(AttentionImpl): class ROCmFlashAttentionImpl(AttentionImpl):
...@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
assert prefill_meta.prompt_lens is not None assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention # triton attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
...@@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None, None,
prefill_meta.seq_start_loc, prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc, prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len, prefill_meta.max_seq_len,
prefill_meta.max_prompt_len, prefill_meta.max_seq_len,
True, True,
self.scale, self.scale,
) )
...@@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query, query,
key, key,
value, value,
prefill_meta.prompt_lens, prefill_meta.seq_lens,
self.scale, self.scale,
) )
else: else:
...@@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value, v=value,
cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len, max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_prompt_len, max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
) )
...@@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
prefill_meta.block_tables, prefill_meta.block_tables,
prefill_meta.subquery_start_loc, prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor, prefill_meta.seq_lens_tensor,
prefill_meta.context_lens, prefill_meta.context_lens_tensor,
prefill_meta.max_subquery_len, prefill_meta.max_query_len,
self.alibi_slopes, self.alibi_slopes,
self.sliding_window[0], self.sliding_window[0],
) )
...@@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache, key_cache,
value_cache, value_cache,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.context_lens, decode_meta.seq_lens_tensor,
decode_meta.max_context_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -334,13 +334,13 @@ def _naive_attention( ...@@ -334,13 +334,13 @@ def _naive_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
prompt_lens: List[int], seq_lens: List[int],
scale: float, scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
start = 0 start = 0
for _, prompt_len in enumerate(prompt_lens): for _, seq_len in enumerate(seq_lens):
end = start + prompt_len end = start + seq_len
out = _naive_masked_attention( out = _naive_masked_attention(
query[start:end], query[start:end],
key[start:end], key[start:end],
...@@ -349,7 +349,7 @@ def _naive_attention( ...@@ -349,7 +349,7 @@ def _naive_attention(
) )
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out) output[start:end].copy_(out)
start += prompt_len start += seq_len
return output return output
......
...@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, ...@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]] seq_lens: Optional[List[int]]
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
...@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale) kv_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
...@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if self.alibi_slopes is not None: if self.alibi_slopes is not None:
att_masks = _make_alibi_bias( att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype, self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None: elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias( att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window, attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore query.dtype) # type: ignore
else: else:
att_masks = [None] * len(attn_metadata.prompt_lens) att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2) query = query.movedim(0, query.dim() - 2)
...@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output = torch.empty( output = torch.empty(
(num_tokens, self.num_heads, self.head_size), (num_tokens, self.num_heads, self.head_size),
dtype=query.dtype) dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens, for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias): attn_metadata.attn_bias):
end = start + prompt_len end = start + seq_len
sub_out = scaled_dot_product_attention( sub_out = scaled_dot_product_attention(
query[:, start:end, :], query[:, start:end, :],
key[:, start:end, :], key[:, start:end, :],
...@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, attn_metadata.block_tables,
attn_metadata.context_lens, attn_metadata.seq_lens_tensor,
attn_metadata.max_context_len, attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
prompt_lens: List[int], seq_lens: List[int],
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases = []
for prompt_len in prompt_lens: for seq_len in seq_lens:
bias = torch.arange(prompt_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi # the bias below more accurately follows the original ALiBi
# paper. # paper.
...@@ -221,7 +221,7 @@ def _make_alibi_bias( ...@@ -221,7 +221,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1)) bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty( inf_mask = torch.empty(
(1, prompt_len, prompt_len), (1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype)) attn_biases.append((bias + inf_mask).to(dtype))
...@@ -229,14 +229,14 @@ def _make_alibi_bias( ...@@ -229,14 +229,14 @@ def _make_alibi_bias(
def _make_sliding_window_bias( def _make_sliding_window_bias(
prompt_lens: List[int], seq_lens: List[int],
window_size: Optional[int], window_size: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases = []
for prompt_len in prompt_lens: for seq_len in seq_lens:
tensor = torch.full( tensor = torch.full(
(1, prompt_len, prompt_len), (1, seq_len, seq_len),
dtype=dtype, dtype=dtype,
fill_value=1, fill_value=1,
) )
......
...@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): ...@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding. # (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens: Optional[List[int]] # the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor. seq_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor] # seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---| # |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------| # |---------- context_len ----------|
# |-------------------- seqlen ----------------------| # |-------------------- seq_len ----------------------|
# |- subquery_len -| # |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is # Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens. max_query_len: Optional[int]
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# FIXME: It is for flash attn. # FIXME: It is for flash attn.
# Maximum prompt length in the batch. # Maximum sequence length in the batch.
max_prompt_len: Optional[int] max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10]. # is [4, 6], it is [0, 4, 10].
...@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): ...@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is # the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10]. # [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled. # Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
...@@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl): ...@@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl):
value_cache, value_cache,
prefill_meta.block_tables, prefill_meta.block_tables,
prefill_meta.subquery_start_loc, prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor, prefill_meta.seq_lens_tensor,
prefill_meta.context_lens, prefill_meta.context_lens_tensor,
prefill_meta.max_subquery_len, prefill_meta.max_query_len,
self.alibi_slopes, self.alibi_slopes,
self.sliding_window, self.sliding_window,
) )
...@@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl): ...@@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache, key_cache,
value_cache, value_cache,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.context_lens, decode_meta.seq_lens_tensor,
decode_meta.max_context_len, decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl): ...@@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
""" """
assert attn_metadata.prompt_lens is not None assert attn_metadata.seq_lens is not None
original_query = query original_query = query
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K]. # GQA/MQA requires the shape [B, M, G, H, K].
...@@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl): ...@@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if attn_metadata.attn_bias is None: if attn_metadata.attn_bias is None:
if self.alibi_slopes is None: if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.prompt_lens) attn_metadata.seq_lens)
if self.sliding_window is not None: if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention( attn_bias = attn_bias.make_local_attention(
self.sliding_window) self.sliding_window)
...@@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl): ...@@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else: else:
attn_metadata.attn_bias = _make_alibi_bias( attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype, self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.prompt_lens) attn_metadata.seq_lens)
# No alibi slopes. # No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
...@@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl): ...@@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query) output = torch.empty_like(original_query)
start = 0 start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens): for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + prompt_len end = start + seq_len
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query[None, start:end], query[None, start:end],
key[None, start:end], key[None, start:end],
...@@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl): ...@@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale=self.scale) scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end])) output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len start += seq_len
return output return output
...@@ -362,13 +361,13 @@ def _make_alibi_bias( ...@@ -362,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
num_kv_heads: int, num_kv_heads: int,
dtype: torch.dtype, dtype: torch.dtype,
prompt_lens: List[int], seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias: ) -> LowerTriangularMaskWithTensorBias:
attn_biases = [] attn_biases = []
for prompt_len in prompt_lens: for seq_len in seq_lens:
bias = torch.arange(prompt_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi # the bias below more accurately follows the original ALiBi
# paper. # paper.
...@@ -376,16 +375,16 @@ def _make_alibi_bias( ...@@ -376,16 +375,16 @@ def _make_alibi_bias(
# element. # element.
bias = bias[None, :] - bias[:, None] bias = bias[None, :] - bias[:, None]
padded_len = (prompt_len + 7) // 8 * 8 padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0] num_heads = alibi_slopes.shape[0]
bias = torch.empty( bias = torch.empty(
1, # batch size 1, # batch size
num_heads, num_heads,
prompt_len, seq_len,
padded_len, padded_len,
device=alibi_slopes.device, device=alibi_slopes.device,
dtype=dtype, dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias) )[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads: if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
......
...@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512 ...@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
"""Metadata for PagedAttention.""" """Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per # (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence. WARNING: When it is a prefill request, it doesn't include new # sequence.
# tokens. When it is for decoding, it includes a new token. seq_lens_tensor: Optional[torch.Tensor]
context_lens: Optional[torch.Tensor] # Maximum sequence length in the batch.
# Maximum context length in the batch. max_seq_len: Optional[int]
max_context_len: Optional[int]
# (batch_size, max_blocks_per_seq). # (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block) # Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
...@@ -85,8 +84,8 @@ class PagedAttention: ...@@ -85,8 +84,8 @@ class PagedAttention:
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
block_tables: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, seq_lens: torch.Tensor,
max_context_len: int, max_seq_len: int,
kv_cache_dtype: str, kv_cache_dtype: str,
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
...@@ -97,7 +96,7 @@ class PagedAttention: ...@@ -97,7 +96,7 @@ class PagedAttention:
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE) _PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
...@@ -106,7 +105,7 @@ class PagedAttention: ...@@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_context_len <= 8192 use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)) and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
...@@ -118,9 +117,9 @@ class PagedAttention: ...@@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -150,9 +149,9 @@ class PagedAttention: ...@@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -168,9 +167,9 @@ class PagedAttention: ...@@ -168,9 +167,9 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
block_tables: torch.Tensor, block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor, subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor, seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor, context_lens: torch.Tensor,
max_subquery_len: int, max_query_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int], sliding_window: Optional[int],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -185,9 +184,9 @@ class PagedAttention: ...@@ -185,9 +184,9 @@ class PagedAttention:
block_tables, block_tables,
# subquery_start_loc is (batch_size + 1,) # subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1], subquery_start_loc[:-1],
prompt_lens_tensor, seq_lens_tensor,
context_lens, context_lens,
max_subquery_len, max_query_len,
alibi_slopes, alibi_slopes,
sliding_window, sliding_window,
) )
......
...@@ -63,7 +63,10 @@ class ModelConfig: ...@@ -63,7 +63,10 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode. to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer. detokenizer.
""" """
...@@ -84,6 +87,7 @@ class ModelConfig: ...@@ -84,6 +87,7 @@ class ModelConfig:
quantization_param_path: Optional[str] = None, quantization_param_path: Optional[str] = None,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5, max_logprobs: int = 5,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
) -> None: ) -> None:
...@@ -99,6 +103,11 @@ class ModelConfig: ...@@ -99,6 +103,11 @@ class ModelConfig:
self.quantization_param_path = quantization_param_path self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
...@@ -190,10 +199,10 @@ class ModelConfig: ...@@ -190,10 +199,10 @@ class ModelConfig:
"non-quantized models.", self.quantization) "non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None: def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None: if self.max_seq_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len self.max_seq_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
...@@ -772,8 +781,8 @@ class SpeculativeConfig: ...@@ -772,8 +781,8 @@ class SpeculativeConfig:
max_model_len=None, max_model_len=None,
quantization=draft_quantization, quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager, enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config. max_seq_len_to_capture=target_model_config.
max_context_len_to_capture, max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs, max_logprobs=target_model_config.max_logprobs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment