Commit 499b1dc6 authored by PanZezhong's avatar PanZezhong
Browse files

issue/867 pass total kv lens as paged attn args

parent 0a2839a2
...@@ -16,7 +16,7 @@ public: ...@@ -16,7 +16,7 @@ public:
* 3. k_cache: Physical Key cache (Paged format) * 3. k_cache: Physical Key cache (Paged format)
* 4. v_cache: Physical Value cache (Paged format) * 4. v_cache: Physical Value cache (Paged format)
* 5. block_tables: Mapping table from logical blocks to physical blocks * 5. block_tables: Mapping table from logical blocks to physical blocks
* 6. history_lens: Historical KV lengths (existing length of each sequence in cache) * 6. total_kv_lens: lengths of Complete Key/Value for each request
* 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch) * 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
* 8. alibi_slopes: ALiBi bias slopes (optional) * 8. alibi_slopes: ALiBi bias slopes (optional)
* 9. scale: Scaling factor (typically 1/sqrt(head_size)) * 9. scale: Scaling factor (typically 1/sqrt(head_size))
...@@ -24,7 +24,7 @@ public: ...@@ -24,7 +24,7 @@ public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float); using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale); std::optional<Tensor> alibi_slopes, float scale);
static common::OpDispatcher<schema> &dispatcher(); static common::OpDispatcher<schema> &dispatcher();
...@@ -34,8 +34,8 @@ Tensor paged_attention_prefill(Tensor q, ...@@ -34,8 +34,8 @@ Tensor paged_attention_prefill(Tensor q,
Tensor k_cache, Tensor k_cache,
Tensor v_cache, Tensor v_cache,
Tensor block_tables, Tensor block_tables,
Tensor history_lens, Tensor total_kv_lens,
Tensor cu_seqlens_q, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, std::optional<Tensor> alibi_slopes,
float scale); float scale);
...@@ -44,8 +44,8 @@ void paged_attention_prefill_(Tensor out, ...@@ -44,8 +44,8 @@ void paged_attention_prefill_(Tensor out,
Tensor k_cache, Tensor k_cache,
Tensor v_cache, Tensor v_cache,
Tensor block_tables, Tensor block_tables,
Tensor history_lens, Tensor total_kv_lens,
Tensor cu_seqlens_q, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, std::optional<Tensor> alibi_slopes,
float scale); float scale);
......
...@@ -20,7 +20,7 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t; ...@@ -20,7 +20,7 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size] * Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks. * @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
* Shape: [batch_size, max_blocks_per_seq] * Shape: [batch_size, max_blocks_per_seq]
* @param history_lens_desc Descriptor for the KV history lengths of each sequence. * @param seq_lens_desc Descriptor for the total KV lengths of each sequence.
* Shape: [batch_size] * Shape: [batch_size]
* @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence. * @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
* Shape: [batch_size + 1] * Shape: [batch_size + 1]
...@@ -37,7 +37,7 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( ...@@ -37,7 +37,7 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc, infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc, infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc, infiniopTensorDescriptor_t alibi_slopes_desc,
float scale); float scale);
...@@ -58,7 +58,7 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( ...@@ -58,7 +58,7 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
* @param k_cache Pointer to the global key cache data. * @param k_cache Pointer to the global key cache data.
* @param v_cache Pointer to the global value cache data. * @param v_cache Pointer to the global value cache data.
* @param block_tables Pointer to the block tables data. * @param block_tables Pointer to the block tables data.
* @param history_lens Pointer to the KV history lengths data. * @param seq_lens Pointer to the KV lengths data.
* @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum). * @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL. * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The device stream (e.g., cudaStream_t) for the operation. * @param stream The device stream (e.g., cudaStream_t) for the operation.
...@@ -73,7 +73,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill( ...@@ -73,7 +73,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
const void *k_cache, const void *k_cache,
const void *v_cache, const void *v_cache,
const void *block_tables, const void *block_tables,
const void *history_lens, const void *seq_lens,
const void *cum_seq_lens_q, const void *cum_seq_lens_q,
const void *alibi_slopes, const void *alibi_slopes,
void *stream); void *stream);
......
...@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() { ...@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
return dispatcher_; return dispatcher_;
}; };
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) { void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
infinicore::context::setDevice(out->device()); infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale); dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
} }
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) { Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale); paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
return out; return out;
} }
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) { void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale); PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches( ...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
} }
}); });
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) { void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale); size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
auto device = context::getDevice(); auto device = context::getDevice();
auto &cache = caches.getCache(device); auto &cache = caches.getCache(device);
...@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc ...@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor( INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc, context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(), out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale)); scale));
cache.put(seed, desc); cache.put(seed, desc);
...@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc ...@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
INFINICORE_CHECK_ERROR(infiniopPagedAttention( INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size, desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream())); context::getStream()));
} }
......
...@@ -10,31 +10,30 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp ...@@ -10,31 +10,30 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
}; };
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) { std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q);
infinicore::context::setDevice(out->device()); infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
history_lens, cu_seqlens_q, alibi_slopes, scale); kv_lens, cum_seqlens_q, alibi_slopes, scale);
} }
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) { std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale); paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
return out; return out;
} }
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) { std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale); PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -16,10 +16,9 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t> ...@@ -16,10 +16,9 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
}); });
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q, Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) { std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
auto device = context::getDevice(); auto device = context::getDevice();
auto &cache = caches.getCache(device); auto &cache = caches.getCache(device);
...@@ -35,8 +34,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, ...@@ -35,8 +34,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
k_cache->desc(), k_cache->desc(),
v_cache->desc(), v_cache->desc(),
block_tables->desc(), block_tables->desc(),
history_lens->desc(), kv_lens->desc(),
cu_seqlens_q->desc(), cum_seqlens_q->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale)); scale));
cache.put(seed, desc); cache.put(seed, desc);
...@@ -57,8 +56,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, ...@@ -57,8 +56,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
k_cache->data(), k_cache->data(),
v_cache->data(), v_cache->data(),
block_tables->data(), block_tables->data(),
history_lens->data(), kv_lens->data(),
cu_seqlens_q->data(), cum_seqlens_q->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream())); context::getStream()));
} }
......
...@@ -19,7 +19,8 @@ Tensor py_paged_attention_prefill(Tensor q, ...@@ -19,7 +19,8 @@ Tensor py_paged_attention_prefill(Tensor q,
if (!alibi_slopes.is_none()) { if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>(); alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
} }
return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale); return op::paged_attention_prefill(
q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
} }
void py_paged_attention_prefill_(Tensor out, void py_paged_attention_prefill_(Tensor out,
......
...@@ -22,12 +22,13 @@ template <typename Tdata, typename Tcompute> ...@@ -22,12 +22,13 @@ template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel( __global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_, Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_, const int64_t *block_tables_,
const int64_t *history_lens_, const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_, const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_, const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale, const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size, const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride, const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
const size_t head_size, const size_t head_size,
const size_t num_seqs) { const size_t num_seqs) {
...@@ -44,10 +45,12 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -44,10 +45,12 @@ __global__ void pagedAttentionPrefillKernel(
size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx]; size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];
const int64_t history_len = history_lens_[seq_idx]; const size_t total_kv_len = total_kv_lens_[seq_idx];
const int64_t causal_limit = history_len + q_token_idx; const size_t q_len = cum_seq_lens_q_[seq_idx + 1] - cum_seq_lens_q_[seq_idx];
const size_t history_len = total_kv_len - q_len;
const size_t causal_limit = history_len + q_token_idx;
const Tdata *q_vec = q_ + global_token_idx * num_heads * head_size + head_idx * head_size; const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size; Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const size_t num_queries_per_kv = num_heads / num_kv_heads; const size_t num_queries_per_kv = num_heads / num_kv_heads;
...@@ -57,10 +60,10 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -57,10 +60,10 @@ __global__ void pagedAttentionPrefillKernel(
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
Tcompute max_score = -FLT_MAX; Tcompute max_score = -FLT_MAX;
for (int64_t t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const size_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const size_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
...@@ -77,10 +80,10 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -77,10 +80,10 @@ __global__ void pagedAttentionPrefillKernel(
} }
Tcompute sum_exp = 0.0f; Tcompute sum_exp = 0.0f;
for (int64_t t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const size_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const size_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
...@@ -96,10 +99,10 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -96,10 +99,10 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute acc = 0.0f; Tcompute acc = 0.0f;
Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f); Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
for (int64_t t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const size_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const size_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
......
...@@ -25,6 +25,7 @@ public: ...@@ -25,6 +25,7 @@ public:
size_t total_q_tokens; size_t total_q_tokens;
ptrdiff_t q_stride; ptrdiff_t q_stride;
ptrdiff_t q_head_stride;
ptrdiff_t kv_block_stride; ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride; ptrdiff_t kv_head_stride;
ptrdiff_t o_stride; ptrdiff_t o_stride;
...@@ -35,7 +36,7 @@ public: ...@@ -35,7 +36,7 @@ public:
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc, infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc, infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) { float scale) {
...@@ -47,7 +48,7 @@ public: ...@@ -47,7 +48,7 @@ public:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || history_lens_desc->dtype() != INFINI_DTYPE_I64) { if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -57,7 +58,7 @@ public: ...@@ -57,7 +58,7 @@ public:
auto k_shape = k_cache_desc->shape(); auto k_shape = k_cache_desc->shape();
auto v_shape = v_cache_desc->shape(); auto v_shape = v_cache_desc->shape();
auto block_tables_shape = block_tables_desc->shape(); auto block_tables_shape = block_tables_desc->shape();
auto history_lens_shape = history_lens_desc->shape(); auto seq_lens_shape = seq_lens_desc->shape();
auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape(); auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
if (k_shape.size() != 4 || v_shape.size() != 4) { if (k_shape.size() != 4 || v_shape.size() != 4) {
...@@ -68,10 +69,11 @@ public: ...@@ -68,10 +69,11 @@ public:
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (history_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) { if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (cum_seq_lens_q_shape[0] != history_lens_shape[0] + 1) {
if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
} }
...@@ -88,13 +90,13 @@ public: ...@@ -88,13 +90,13 @@ public:
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
} }
size_t num_seqs = history_lens_shape[0]; size_t num_seqs = seq_lens_shape[0];
size_t num_kv_heads = k_shape[1]; size_t num_kv_heads = k_shape[1];
size_t block_size = k_shape[2]; size_t block_size = k_shape[2];
size_t max_num_blocks_per_seq = block_tables_shape[1]; size_t max_num_blocks_per_seq = block_tables_shape[1];
ptrdiff_t q_stride = q_desc->stride(0); ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t q_head_stride = q_desc->stride(1);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0); ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1); ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0); ptrdiff_t o_stride = out_desc->stride(0);
...@@ -110,6 +112,7 @@ public: ...@@ -110,6 +112,7 @@ public:
max_num_blocks_per_seq, max_num_blocks_per_seq,
total_q_tokens, total_q_tokens,
q_stride, q_stride,
q_head_stride,
kv_block_stride, kv_block_stride,
kv_head_stride, kv_head_stride,
o_stride}); o_stride});
......
...@@ -12,7 +12,7 @@ template <typename Tdata, typename Tcompute> ...@@ -12,7 +12,7 @@ template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill( infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *block_tables,
const int64_t *history_lens, const int64_t *seq_lens,
const int64_t *cum_seq_lens_q, const int64_t *cum_seq_lens_q,
const float *alibi_slopes, const float *alibi_slopes,
const size_t num_heads, const size_t num_heads,
...@@ -25,6 +25,8 @@ infiniStatus_t launchPagedAttentionPrefill( ...@@ -25,6 +25,8 @@ infiniStatus_t launchPagedAttentionPrefill(
const size_t head_size, const size_t head_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
cudaStream_t stream) { cudaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) { if (total_q_tokens == 0 || num_heads == 0) {
...@@ -37,10 +39,11 @@ infiniStatus_t launchPagedAttentionPrefill( ...@@ -37,10 +39,11 @@ infiniStatus_t launchPagedAttentionPrefill(
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute> op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, out, q, k_cache, v_cache,
block_tables, history_lens, cum_seq_lens_q, alibi_slopes, block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale, num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, block_size, max_num_blocks_per_seq, block_size,
kv_block_stride, kv_head_stride, kv_block_stride, kv_head_stride,
q_stride, q_head_stride,
head_size, head_size,
num_seqs); num_seqs);
...@@ -65,14 +68,14 @@ infiniStatus_t Descriptor::create( ...@@ -65,14 +68,14 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc, infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc, infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) { float scale) {
auto info = PagedAttentionPrefillInfo::create( auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc, out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, history_lens_desc, block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc, cum_seq_lens_q_desc,
alibi_slopes_desc, scale); alibi_slopes_desc, scale);
...@@ -89,23 +92,24 @@ infiniStatus_t Descriptor::calculate( ...@@ -89,23 +92,24 @@ infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache, void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *block_tables,
const void *history_lens, const void *seq_lens,
const void *cum_seq_lens_q, const void *cum_seq_lens_q,
const void *alibi_slopes, const void *alibi_slopes,
void *stream_) const { void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_; cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \ launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)history_lens, (const int64_t *)cum_seq_lens_q, \ (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \ (const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \ _info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \ _info.block_size, _info.total_q_tokens, \
_info.head_size, \ _info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \ _info.kv_block_stride, _info.kv_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream) stream)
if (_info.dtype == INFINI_DTYPE_F16) { if (_info.dtype == INFINI_DTYPE_F16) {
......
...@@ -14,7 +14,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( ...@@ -14,7 +14,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t history_lens_desc, infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc, infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc, infiniopTensorDescriptor_t alibi_slopes_desc,
float scale) { float scale) {
...@@ -27,14 +27,15 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( ...@@ -27,14 +27,15 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
handle, \ handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \ out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
history_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale); seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia) CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
...@@ -50,8 +51,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( ...@@ -50,8 +51,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia) GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopPagedAttentionPrefill( __C infiniStatus_t infiniopPagedAttentionPrefill(
...@@ -59,7 +61,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill( ...@@ -59,7 +61,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache, void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *block_tables,
const void *history_lens, const void *seq_lens,
const void *cum_seq_lens_q, const void *cum_seq_lens_q,
const void *alibi_slopes, const void *alibi_slopes,
void *stream) { void *stream) {
...@@ -68,14 +70,15 @@ __C infiniStatus_t infiniopPagedAttentionPrefill( ...@@ -68,14 +70,15 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
case CASE: \ case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \ return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \ workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
history_lens, cum_seq_lens_q, alibi_slopes, stream); seq_lens, cum_seq_lens_q, alibi_slopes, stream);
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
...@@ -90,6 +93,7 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( ...@@ -90,6 +93,7 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia) DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
infiniopTensorDescriptor_t k_cache_desc, \ infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \ infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \ infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t history_lens_desc, \ infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t cum_seq_lens_q_desc, \ infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \ const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \ float scale); \
...@@ -46,7 +46,7 @@ ...@@ -46,7 +46,7 @@
void *workspace, size_t workspace_size, \ void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \ void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, \ const void *block_tables, \
const void *history_lens, \ const void *seq_lens, \
const void *cum_seq_lens_q, \ const void *cum_seq_lens_q, \
const void *alibi_slopes, \ const void *alibi_slopes, \
void *stream) const; \ void *stream) const; \
......
import sys
import os import os
import sys
import torch import torch
import infinicore import infinicore
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import ( from framework import (
BaseOperatorTest, BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner, GenericTestRunner,
TensorInitializer, TensorInitializer,
TensorSpec,
TestCase,
) )
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds) # Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
...@@ -71,16 +73,17 @@ def parse_test_cases(): ...@@ -71,16 +73,17 @@ def parse_test_cases():
scale = head_size**-0.5 scale = head_size**-0.5
num_blocks = 8192 num_blocks = 8192
manager = SimpleCacheManager(num_blocks, block_size) manager = SimpleCacheManager(num_blocks, block_size)
current_history_lens = torch.zeros(num_seqs, dtype=torch.int64) kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds): for r in range(num_rounds):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64) q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item() total_q_tokens = q_lens.sum().item()
cu_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64) cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0) cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
query_base = torch.randn((total_q_tokens, num_heads, head_size)) query_base = torch.randn((total_q_tokens, num_heads, head_size))
...@@ -89,8 +92,7 @@ def parse_test_cases(): ...@@ -89,8 +92,7 @@ def parse_test_cases():
p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item()) p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
round_block_tables_list.append(p_blocks) round_block_tables_list.append(p_blocks)
h_len = current_history_lens[i].item() h_len = kv_lens[i].item() - q_lens[i].item()
q_start = cu_seqlens_q[i].item()
for t in range(q_lens[i].item()): for t in range(q_lens[i].item()):
logical_pos = h_len + t logical_pos = h_len + t
...@@ -135,15 +137,15 @@ def parse_test_cases(): ...@@ -135,15 +137,15 @@ def parse_test_cases():
dtype=infinicore.int64, dtype=infinicore.int64,
), ),
TensorSpec.from_tensor( TensorSpec.from_tensor(
current_history_lens.shape, kv_lens.shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=current_history_lens.clone(), set_tensor=kv_lens.clone(),
dtype=infinicore.int64, dtype=infinicore.int64,
), ),
TensorSpec.from_tensor( TensorSpec.from_tensor(
cu_seqlens_q.shape, cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=cu_seqlens_q.clone(), set_tensor=cum_seqlens_q.clone(),
dtype=infinicore.int64, dtype=infinicore.int64,
), ),
], ],
...@@ -153,23 +155,21 @@ def parse_test_cases(): ...@@ -153,23 +155,21 @@ def parse_test_cases():
) )
) )
current_history_lens += q_lens
return test_cases return test_cases
def ref_paged_attention_multi_turn( def ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
): ):
output = torch.zeros_like(query) output = torch.zeros_like(query)
num_seqs = len(history_lens) num_seqs = len(kv_lens)
block_size = k_cache.shape[2] block_size = k_cache.shape[2]
for i in range(num_seqs): for i in range(num_seqs):
q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item() q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
cur_q = query[q_start:q_end] cur_q = query[q_start:q_end]
h_len = history_lens[i].item()
q_len = q_end - q_start q_len = q_end - q_start
h_len = kv_lens[i].item() - q_len
total_len = h_len + q_len total_len = h_len + q_len
table = block_tables[i] table = block_tables[i]
...@@ -206,12 +206,12 @@ class OpTest(BaseOperatorTest): ...@@ -206,12 +206,12 @@ class OpTest(BaseOperatorTest):
k_cache, k_cache,
v_cache, v_cache,
block_tables, block_tables,
history_lens, kv_lens,
cu_seqlens_q, cum_seqlens_q,
scale=1.0, scale=1.0,
): ):
return ref_paged_attention_multi_turn( return ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale query, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, scale
) )
def infinicore_operator( def infinicore_operator(
...@@ -220,8 +220,8 @@ class OpTest(BaseOperatorTest): ...@@ -220,8 +220,8 @@ class OpTest(BaseOperatorTest):
k_cache, k_cache,
v_cache, v_cache,
block_tables, block_tables,
history_lens, kv_lens,
cu_seqlens_q, cum_seqlens_q,
scale=1.0, scale=1.0,
): ):
out = infinicore.paged_attention_prefill( out = infinicore.paged_attention_prefill(
...@@ -229,8 +229,8 @@ class OpTest(BaseOperatorTest): ...@@ -229,8 +229,8 @@ class OpTest(BaseOperatorTest):
k_cache, k_cache,
v_cache, v_cache,
block_tables, block_tables,
history_lens, kv_lens,
cu_seqlens_q, cum_seqlens_q,
alibi_slopes=None, alibi_slopes=None,
scale=scale, scale=scale,
) )
......
import torch
import ctypes import ctypes
from ctypes import c_uint64 from ctypes import c_uint64
import torch
from libinfiniop import ( from libinfiniop import (
LIBINFINIOP, LIBINFINIOP,
InfiniDeviceNames,
InfiniDtype,
InfiniDtypeNames,
TestTensor, TestTensor,
get_test_devices, TestWorkspace,
check_error, check_error,
test_operator,
get_args,
debug, debug,
get_args,
get_test_devices,
get_tolerance, get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
TestWorkspace, profile_operation,
test_operator,
) )
# ============================================================================== # ==============================================================================
...@@ -81,8 +82,8 @@ def ref_paged_attention_multi_turn( ...@@ -81,8 +82,8 @@ def ref_paged_attention_multi_turn(
num_seqs = len(cum_seq_lens_q) - 1 num_seqs = len(cum_seq_lens_q) - 1
for i in range(num_seqs): for i in range(num_seqs):
num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item() num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
cache_len = seq_lens[i].item() total_len = seq_lens[i].item()
total_len = seq_lens[i].item() + num_new cache_len = seq_lens[i].item() - num_new
table = block_tables[i] table = block_tables[i]
keys_all, values_all = [], [] keys_all, values_all = [], []
...@@ -166,7 +167,7 @@ def test( ...@@ -166,7 +167,7 @@ def test(
cur_q_len = query_lens_cpu[i].item() cur_q_len = query_lens_cpu[i].item()
table, total_len = manager.allocate_slots(i, cur_q_len) table, total_len = manager.allocate_slots(i, cur_q_len)
cur_seq_lens = total_len - cur_q_len cur_seq_lens = total_len - cur_q_len
seq_lens_list.append(cur_seq_lens) seq_lens_list.append(total_len)
all_block_tables.append(table) all_block_tables.append(table)
# Simulated KV insertion # Simulated KV insertion
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment