Commit 499b1dc6 authored by PanZezhong's avatar PanZezhong
Browse files

issue/867 pass total kv lens as paged attn args

parent 0a2839a2
......@@ -16,7 +16,7 @@ public:
* 3. k_cache: Physical Key cache (Paged format)
* 4. v_cache: Physical Value cache (Paged format)
* 5. block_tables: Mapping table from logical blocks to physical blocks
* 6. history_lens: Historical KV lengths (existing length of each sequence in cache)
* 6. total_kv_lens: Total Key/Value lengths for each request (history + current step)
* 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
* 8. alibi_slopes: ALiBi bias slopes (optional)
* 9. scale: Scaling factor (typically 1/sqrt(head_size))
......@@ -24,7 +24,7 @@ public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale);
static common::OpDispatcher<schema> &dispatcher();
......@@ -34,8 +34,8 @@ Tensor paged_attention_prefill(Tensor q,
Tensor k_cache,
Tensor v_cache,
Tensor block_tables,
Tensor history_lens,
Tensor cu_seqlens_q,
Tensor total_kv_lens,
Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes,
float scale);
......@@ -44,8 +44,8 @@ void paged_attention_prefill_(Tensor out,
Tensor k_cache,
Tensor v_cache,
Tensor block_tables,
Tensor history_lens,
Tensor cu_seqlens_q,
Tensor total_kv_lens,
Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes,
float scale);
......
......@@ -20,7 +20,7 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
* Shape: [batch_size, max_blocks_per_seq]
* @param history_lens_desc Descriptor for the KV history lengths of each sequence.
* @param seq_lens_desc Descriptor for the total KV lengths of each sequence.
* Shape: [batch_size]
* @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
* Shape: [batch_size + 1]
......@@ -37,7 +37,7 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale);
......@@ -58,7 +58,7 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
* @param k_cache Pointer to the global key cache data.
* @param v_cache Pointer to the global value cache data.
* @param block_tables Pointer to the block tables data.
* @param history_lens Pointer to the KV history lengths data.
* @param seq_lens Pointer to the KV lengths data.
* @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The device stream (e.g., cudaStream_t) for the operation.
......@@ -73,7 +73,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *history_lens,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream);
......
......@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
return dispatcher_;
};
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
}
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
return out;
}
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
}
} // namespace infinicore::op
......@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
}
});
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
......@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
......
......@@ -10,31 +10,30 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
};
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q);
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
history_lens, cu_seqlens_q, alibi_slopes, scale);
kv_lens, cum_seqlens_q, alibi_slopes, scale);
}
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
return out;
}
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
}
} // namespace infinicore::op
......@@ -16,10 +16,9 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
});
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -35,8 +34,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
k_cache->desc(),
v_cache->desc(),
block_tables->desc(),
history_lens->desc(),
cu_seqlens_q->desc(),
kv_lens->desc(),
cum_seqlens_q->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
......@@ -57,8 +56,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
k_cache->data(),
v_cache->data(),
block_tables->data(),
history_lens->data(),
cu_seqlens_q->data(),
kv_lens->data(),
cum_seqlens_q->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
......
......@@ -19,7 +19,8 @@ Tensor py_paged_attention_prefill(Tensor q,
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
return op::paged_attention_prefill(
q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
}
void py_paged_attention_prefill_(Tensor out,
......
......@@ -22,12 +22,13 @@ template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *history_lens_,
const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
const size_t head_size,
const size_t num_seqs) {
......@@ -44,10 +45,12 @@ __global__ void pagedAttentionPrefillKernel(
size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];
const int64_t history_len = history_lens_[seq_idx];
const int64_t causal_limit = history_len + q_token_idx;
const size_t total_kv_len = total_kv_lens_[seq_idx];
const size_t q_len = cum_seq_lens_q_[seq_idx + 1] - cum_seq_lens_q_[seq_idx];
const size_t history_len = total_kv_len - q_len;
const size_t causal_limit = history_len + q_token_idx;
const Tdata *q_vec = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
......@@ -57,10 +60,10 @@ __global__ void pagedAttentionPrefillKernel(
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
Tcompute max_score = -FLT_MAX;
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f;
......@@ -77,10 +80,10 @@ __global__ void pagedAttentionPrefillKernel(
}
Tcompute sum_exp = 0.0f;
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f;
......@@ -96,10 +99,10 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute acc = 0.0f;
Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f;
......
......@@ -25,6 +25,7 @@ public:
size_t total_q_tokens;
ptrdiff_t q_stride;
ptrdiff_t q_head_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t o_stride;
......@@ -35,7 +36,7 @@ public:
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
......@@ -47,7 +48,7 @@ public:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || history_lens_desc->dtype() != INFINI_DTYPE_I64) {
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -57,7 +58,7 @@ public:
auto k_shape = k_cache_desc->shape();
auto v_shape = v_cache_desc->shape();
auto block_tables_shape = block_tables_desc->shape();
auto history_lens_shape = history_lens_desc->shape();
auto seq_lens_shape = seq_lens_desc->shape();
auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
if (k_shape.size() != 4 || v_shape.size() != 4) {
......@@ -68,10 +69,11 @@ public:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (history_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (cum_seq_lens_q_shape[0] != history_lens_shape[0] + 1) {
if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
return INFINI_STATUS_BAD_PARAM;
}
......@@ -88,13 +90,13 @@ public:
return INFINI_STATUS_BAD_PARAM;
}
size_t num_seqs = history_lens_shape[0];
size_t num_seqs = seq_lens_shape[0];
size_t num_kv_heads = k_shape[1];
size_t block_size = k_shape[2];
size_t max_num_blocks_per_seq = block_tables_shape[1];
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t q_head_stride = q_desc->stride(1);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);
......@@ -110,6 +112,7 @@ public:
max_num_blocks_per_seq,
total_q_tokens,
q_stride,
q_head_stride,
kv_block_stride,
kv_head_stride,
o_stride});
......
......@@ -12,7 +12,7 @@ template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *history_lens,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
......@@ -25,6 +25,8 @@ infiniStatus_t launchPagedAttentionPrefill(
const size_t head_size,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
cudaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
......@@ -37,10 +39,11 @@ infiniStatus_t launchPagedAttentionPrefill(
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, history_lens, cum_seq_lens_q, alibi_slopes,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, block_size,
kv_block_stride, kv_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
......@@ -65,14 +68,14 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, history_lens_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
......@@ -89,23 +92,24 @@ infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *history_lens,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)history_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
......
......@@ -14,7 +14,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale) {
......@@ -27,14 +27,15 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
history_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
......@@ -50,8 +51,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopPagedAttentionPrefill(
......@@ -59,7 +61,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *history_lens,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream) {
......@@ -68,14 +70,15 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
history_lens, cum_seq_lens_q, alibi_slopes, stream);
seq_lens, cum_seq_lens_q, alibi_slopes, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
......@@ -90,6 +93,7 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -37,7 +37,7 @@
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t history_lens_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
......@@ -46,7 +46,7 @@
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, \
const void *history_lens, \
const void *seq_lens, \
const void *cum_seq_lens_q, \
const void *alibi_slopes, \
void *stream) const; \
......
import sys
import os
import sys
import torch
import infinicore
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
TensorInitializer,
TensorSpec,
TestCase,
)
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
......@@ -71,16 +73,17 @@ def parse_test_cases():
scale = head_size**-0.5
num_blocks = 8192
manager = SimpleCacheManager(num_blocks, block_size)
current_history_lens = torch.zeros(num_seqs, dtype=torch.int64)
kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item()
cu_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
query_base = torch.randn((total_q_tokens, num_heads, head_size))
......@@ -89,8 +92,7 @@ def parse_test_cases():
p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
round_block_tables_list.append(p_blocks)
h_len = current_history_lens[i].item()
q_start = cu_seqlens_q[i].item()
h_len = kv_lens[i].item() - q_lens[i].item()
for t in range(q_lens[i].item()):
logical_pos = h_len + t
......@@ -135,15 +137,15 @@ def parse_test_cases():
dtype=infinicore.int64,
),
TensorSpec.from_tensor(
current_history_lens.shape,
kv_lens.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=current_history_lens.clone(),
set_tensor=kv_lens.clone(),
dtype=infinicore.int64,
),
TensorSpec.from_tensor(
cu_seqlens_q.shape,
cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=cu_seqlens_q.clone(),
set_tensor=cum_seqlens_q.clone(),
dtype=infinicore.int64,
),
],
......@@ -153,23 +155,21 @@ def parse_test_cases():
)
)
current_history_lens += q_lens
return test_cases
def ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
):
output = torch.zeros_like(query)
num_seqs = len(history_lens)
num_seqs = len(kv_lens)
block_size = k_cache.shape[2]
for i in range(num_seqs):
q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
cur_q = query[q_start:q_end]
h_len = history_lens[i].item()
q_len = q_end - q_start
h_len = kv_lens[i].item() - q_len
total_len = h_len + q_len
table = block_tables[i]
......@@ -206,12 +206,12 @@ class OpTest(BaseOperatorTest):
k_cache,
v_cache,
block_tables,
history_lens,
cu_seqlens_q,
kv_lens,
cum_seqlens_q,
scale=1.0,
):
return ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
)
def infinicore_operator(
......@@ -220,8 +220,8 @@ class OpTest(BaseOperatorTest):
k_cache,
v_cache,
block_tables,
history_lens,
cu_seqlens_q,
kv_lens,
cum_seqlens_q,
scale=1.0,
):
out = infinicore.paged_attention_prefill(
......@@ -229,8 +229,8 @@ class OpTest(BaseOperatorTest):
k_cache,
v_cache,
block_tables,
history_lens,
cu_seqlens_q,
kv_lens,
cum_seqlens_q,
alibi_slopes=None,
scale=scale,
)
......
import torch
import ctypes
from ctypes import c_uint64
import torch
from libinfiniop import (
LIBINFINIOP,
InfiniDeviceNames,
InfiniDtype,
InfiniDtypeNames,
TestTensor,
get_test_devices,
TestWorkspace,
check_error,
test_operator,
get_args,
debug,
get_args,
get_test_devices,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
profile_operation,
test_operator,
)
# ==============================================================================
......@@ -81,8 +82,8 @@ def ref_paged_attention_multi_turn(
num_seqs = len(cum_seq_lens_q) - 1
for i in range(num_seqs):
num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
cache_len = seq_lens[i].item()
total_len = seq_lens[i].item() + num_new
total_len = seq_lens[i].item()
cache_len = seq_lens[i].item() - num_new
table = block_tables[i]
keys_all, values_all = [], []
......@@ -166,7 +167,7 @@ def test(
cur_q_len = query_lens_cpu[i].item()
table, total_len = manager.allocate_slots(i, cur_q_len)
cur_seq_lens = total_len - cur_q_len
seq_lens_list.append(cur_seq_lens)
seq_lens_list.append(total_len)
all_block_tables.append(table)
# Simulated KV insertion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment