"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "e4717df71a4f45bf9f0ac88c6cd9846a0bc248dd"
Commit 0a2839a2 authored by zhushuang's avatar zhushuang
Browse files

issue/867 - feat: adjust paged_attention_prefill interface naming

parent 3b5afffe
......@@ -8,11 +8,45 @@ namespace infinicore::op {
class PagedAttentionPrefill {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
/**
* @brief PagedAttentionPrefill operator signature
 * Argument order:
* 1. out: Output tensor (Packed format)
* 2. q: Current Query tensor (Packed format)
* 3. k_cache: Physical Key cache (Paged format)
* 4. v_cache: Physical Value cache (Paged format)
* 5. block_tables: Mapping table from logical blocks to physical blocks
* 6. history_lens: Historical KV lengths (existing length of each sequence in cache)
* 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
* 8. alibi_slopes: ALiBi bias slopes (optional)
* 9. scale: Scaling factor (typically 1/sqrt(head_size))
*/
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
Tensor paged_attention_prefill(Tensor q,
Tensor k_cache,
Tensor v_cache,
Tensor block_tables,
Tensor history_lens,
Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes,
float scale);
void paged_attention_prefill_(Tensor out,
Tensor q,
Tensor k_cache,
Tensor v_cache,
Tensor block_tables,
Tensor history_lens,
Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes,
float scale);
} // namespace infinicore::op
......@@ -11,15 +11,22 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param out_desc Descriptor for the output tensor.
* Shape: [total_q_tokens, num_heads, head_size]
* @param q_desc Descriptor for the query tensor (packed/flattened).
* Shape: [total_q_tokens, num_heads, head_size]
* @param k_cache_desc Descriptor for the global physical key cache.
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param v_cache_desc Descriptor for the global physical value cache.
* Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
* @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
* @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
* @param seq_lens_desc Descriptor for the current prefill sequence lengths.
* @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
* Shape: [batch_size, max_blocks_per_seq]
* @param history_lens_desc Descriptor for the KV history lengths of each sequence.
* Shape: [batch_size]
* @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
* Shape: [batch_size + 1]
* @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
* @param scale The attention scaling factor.
* Shape: [num_heads]
* @param scale The attention scaling factor (typically 1.0 / sqrt(head_size)).
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
......@@ -30,9 +37,8 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t offset_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale);
......@@ -52,11 +58,10 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
* @param k_cache Pointer to the global key cache data.
* @param v_cache Pointer to the global value cache data.
* @param block_tables Pointer to the block tables data.
* @param cache_lens Pointer to the total sequence lengths data.
* @param seq_lens Pointer to the current prefill sequence lengths data.
* @param offset Pointer to the sequence start offsets data.
* @param history_lens Pointer to the KV history lengths data.
* @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The CUDA/device stream for the operation.
* @param stream The device stream (e.g., cudaStream_t) for the operation.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
......@@ -68,9 +73,8 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *cache_lens,
const void *seq_lens,
const void *offset,
const void *history_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream);
......
......@@ -7,14 +7,15 @@ def paged_attention_prefill(
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
cache_lens: Tensor,
seq_lens: Tensor,
seq_offsets: Tensor,
history_lens: Tensor,
cu_seqlens_q: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
alibi_ptr = alibi_slopes._underlying if alibi_slopes is not None else None
if out is None:
return Tensor(
_infinicore.paged_attention_prefill(
......@@ -22,10 +23,9 @@ def paged_attention_prefill(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
history_lens._underlying,
cu_seqlens_q._underlying,
alibi_ptr,
scale,
)
)
......@@ -36,10 +36,9 @@ def paged_attention_prefill(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
history_lens._underlying,
cu_seqlens_q._underlying,
alibi_ptr,
scale,
)
......
......@@ -9,20 +9,32 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
return dispatcher_;
};
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
history_lens, cu_seqlens_q, alibi_slopes, scale);
}
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
return out;
}
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
}
} // namespace infinicore::op
......@@ -15,8 +15,11 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
}
});
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -27,8 +30,13 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(),
cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
out->desc(),
q->desc(),
k_cache->desc(),
v_cache->desc(),
block_tables->desc(),
history_lens->desc(),
cu_seqlens_q->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
......@@ -41,8 +49,16 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), seq_lens->data(), seq_offsets->data(),
desc,
workspace->data(),
workspace_size,
out->data(),
q->data(),
k_cache->data(),
v_cache->data(),
block_tables->data(),
history_lens->data(),
cu_seqlens_q->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
......
......@@ -11,6 +11,7 @@
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
......@@ -33,6 +34,7 @@ inline void bind(py::module &m) {
bind_matmul(m);
bind_mul(m);
bind_paged_attention(m);
bind_paged_attention_prefill(m);
bind_paged_caching(m);
bind_rearrange(m);
bind_rms_norm(m);
......
#pragma once
#include "infinicore/ops/paged_attention_prefill.hpp"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace infinicore::ops {
// Python-facing wrapper for the out-of-place prefill operator.
// Converts `alibi_slopes` (None or a Tensor from Python) into a
// std::optional<Tensor> and forwards to op::paged_attention_prefill,
// returning the freshly allocated output tensor.
Tensor py_paged_attention_prefill(Tensor q,
                                  Tensor k_cache,
                                  Tensor v_cache,
                                  Tensor block_tables,
                                  Tensor history_lens,
                                  Tensor cu_seqlens_q,
                                  py::object alibi_slopes,
                                  float scale) {
    const std::optional<Tensor> alibi =
        alibi_slopes.is_none()
            ? std::nullopt
            : std::optional<Tensor>(alibi_slopes.cast<Tensor>());
    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables,
                                       history_lens, cu_seqlens_q,
                                       alibi, scale);
}
// Python-facing wrapper for the in-place prefill operator: writes the
// attention result into `out`. `alibi_slopes` may be None (no ALiBi bias)
// or a Tensor; it is converted to std::optional<Tensor> before forwarding.
void py_paged_attention_prefill_(Tensor out,
                                 Tensor q,
                                 Tensor k_cache,
                                 Tensor v_cache,
                                 Tensor block_tables,
                                 Tensor history_lens,
                                 Tensor cu_seqlens_q,
                                 py::object alibi_slopes,
                                 float scale) {
    const std::optional<Tensor> alibi =
        alibi_slopes.is_none()
            ? std::nullopt
            : std::optional<Tensor>(alibi_slopes.cast<Tensor>());
    op::paged_attention_prefill_(out, q, k_cache, v_cache, block_tables,
                                 history_lens, cu_seqlens_q,
                                 alibi, scale);
}
// Registers the paged-attention prefill operators on the given pybind11
// module:
//   - "paged_attention_prefill"  : out-of-place; allocates and returns the
//                                  output tensor.
//   - "paged_attention_prefill_" : in-place; same arguments preceded by the
//                                  `out` tensor that receives the result.
// In both, `alibi_slopes` defaults to None (no ALiBi bias) and `scale`
// defaults to 1.0.
inline void bind_paged_attention_prefill(py::module &m) {
m.def("paged_attention_prefill",
&ops::py_paged_attention_prefill,
py::arg("q"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("history_lens"),
py::arg("cu_seqlens_q"),
py::arg("alibi_slopes") = py::none(),
py::arg("scale") = 1.0,
R"doc(Paged attention prefill for packed variable-length queries.)doc");
// In-place variant: identical argument list with the output tensor first.
m.def("paged_attention_prefill_",
&ops::py_paged_attention_prefill_,
py::arg("out"),
py::arg("q"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("history_lens"),
py::arg("cu_seqlens_q"),
py::arg("alibi_slopes") = py::none(),
py::arg("scale") = 1.0,
R"doc(In-place paged attention prefill for packed variable-length queries.)doc");
}
} // namespace infinicore::ops
......@@ -3,14 +3,13 @@
namespace op::paged_attention_prefill::cuda {
// Helper: binary-search which sequence the current global_token_idx belongs to
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0, high = num_seqs - 1;
while (low <= high) {
size_t mid = (low + high) >> 1;
if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
return mid;
} else if (token_idx < offset[mid]) {
} else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
high = mid - 1;
} else {
low = mid + 1;
......@@ -22,50 +21,43 @@ __device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *o
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_, const int64_t *cache_lens_, const int64_t *seq_lens_,
const int64_t *block_tables_,
const int64_t *history_lens_,
const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const size_t head_size,
const int64_t *offset_,
const size_t num_seqs) {
// --- 使用 2D Grid 坐标 ---
const size_t global_token_idx = blockIdx.x; // 展平后的全局 token 索引
const size_t head_idx = blockIdx.y; // Head 索引
const size_t dim_idx = threadIdx.x; // Head 内部维度
// Grid : x -> token, y -> head
const size_t global_token_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
const size_t dim_idx = threadIdx.x;
if (dim_idx >= head_size) {
return;
}
// --- 通过二分查找 offset 找到所属的 seq_idx ---
size_t seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
size_t seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
// --- 获取该 Sequence 本次 Prefill 的长度
const int64_t cur_new_len = seq_lens_[seq_idx];
size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];
// --- 该 token 在当前序列中的相对位置
size_t q_token_idx = global_token_idx - offset_[seq_idx];
const int64_t history_len = history_lens_[seq_idx];
const int64_t causal_limit = history_len + q_token_idx;
const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const Tdata *q_vec = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
// --- KV Cache 相关信息
const int64_t total_seq_len = cache_lens_[seq_idx];
const int64_t history_len = total_seq_len - cur_new_len;
const int64_t causal_limit = history_len + q_token_idx;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
// Pass 1: 计算 Score 并找最大值
Tcompute max_score = -FLT_MAX;
for (size_t t = 0; t <= causal_limit; ++t) {
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
......@@ -73,7 +65,7 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute score = 0.0f;
for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
......@@ -84,9 +76,8 @@ __global__ void pagedAttentionPrefillKernel(
}
}
// Pass 2: 计算 Sum of Exp
Tcompute sum_exp = 0.0f;
for (size_t t = 0; t <= causal_limit; ++t) {
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
......@@ -94,7 +85,7 @@ __global__ void pagedAttentionPrefillKernel(
Tcompute score = 0.0f;
for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
......@@ -103,10 +94,9 @@ __global__ void pagedAttentionPrefillKernel(
sum_exp += expf(static_cast<float>(score - max_score));
}
// Pass 3: 加权求和得到输出
Tcompute acc = 0.0f;
Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
for (size_t t = 0; t <= causal_limit; ++t) {
for (int64_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx];
......@@ -114,7 +104,7 @@ __global__ void pagedAttentionPrefillKernel(
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f;
for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
......
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_INFO_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
......@@ -35,9 +35,8 @@ public:
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t offset_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
......@@ -47,39 +46,54 @@ public:
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || history_lens_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl;
}
auto k_shape = k_cache_desc->shape();
auto v_shape = v_cache_desc->shape();
auto block_tables_shape = block_tables_desc->shape();
auto history_lens_shape = history_lens_desc->shape();
auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
if (k_shape.size() != 4 || v_shape.size() != 4) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_shape.size() != 2) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (history_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (cum_seq_lens_q_shape[0] != history_lens_shape[0] + 1) {
return INFINI_STATUS_BAD_PARAM;
}
// Q shape: [total_tokens, heads, dim] (3D)
// Q shape: [total_tokens, heads, dim]
auto q_shape = q_desc->shape();
if (q_shape.size() < 3) {
if (q_shape.size() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t total_q_tokens = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
size_t num_heads = q_shape[q_shape.size() - 2];
size_t head_size = q_shape[q_shape.size() - 1];
if (head_size != 128) {
std::cerr << "[Error] PagedAttentionPrefill head_size = 128 supported, got " << head_size << std::endl;
return INFINI_STATUS_BAD_TENSOR_SHAPE;
if (head_size > 1024) {
return INFINI_STATUS_BAD_PARAM;
}
// 从 seq_lens 获取 num_seqs
size_t num_seqs = seq_lens_desc->shape()[0];
size_t num_seqs = history_lens_shape[0];
auto k_cache_shape = k_cache_desc->shape();
size_t num_kv_heads = k_cache_shape[1];
size_t block_size = v_cache_desc->shape()[2];
size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
size_t num_kv_heads = k_shape[1];
size_t block_size = k_shape[2];
size_t max_num_blocks_per_seq = block_tables_shape[1];
// 提取步长,需要保持多个请求的 Q 连续
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
......
......@@ -8,14 +8,12 @@
#include "../cuda/kernel.cuh"
#include "paged_attention_prefill_nvidia.cuh"
// ==============================================================================
// Host wrapper to launch the global kernel
// ==============================================================================
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *cache_lens, const int64_t *seq_lens,
const int64_t *offset,
const int64_t *block_tables,
const int64_t *history_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
......@@ -24,36 +22,27 @@ infiniStatus_t launchPagedAttentionPrefill(
const size_t max_num_blocks_per_seq,
const size_t block_size,
const size_t total_q_tokens,
const ptrdiff_t q_stride,
const size_t head_size,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride,
const size_t head_size,
cudaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// 使用 2D Grid: X轴是所有 Token,Y轴是所有 Head
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, cache_lens, seq_lens, alibi_slopes,
block_tables, history_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, block_size,
kv_block_stride, kv_head_stride,
head_size,
offset, num_seqs);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "CUDA Kernel Launch Failed: " << cudaGetErrorString(err) << std::endl;
return INFINI_STATUS_INTERNAL_ERROR;
}
num_seqs);
return INFINI_STATUS_SUCCESS;
}
......@@ -76,16 +65,17 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t offset_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, cache_lens_desc, seq_lens_desc,
offset_desc,
alibi_slopes_desc, scale);
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, history_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
......@@ -98,28 +88,24 @@ infiniStatus_t Descriptor::create(
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *cache_lens, const void *seq_lens,
const void *offset,
const void *block_tables,
const void *history_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_info.head_size > 1024) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens, \
(const int64_t *)offset, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
_info.head_size, \
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)history_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
......
......@@ -14,9 +14,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t offset_desc,
infiniopTensorDescriptor_t history_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale) {
......@@ -27,8 +26,8 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, \
seq_lens_desc, offset_desc, alibi_opt, scale);
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
history_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
......@@ -59,8 +58,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
infiniopPagedAttentionPrefillDescriptor_t desc,
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *cache_lens, const void *seq_lens,
const void *offset,
const void *block_tables,
const void *history_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream) {
......@@ -68,7 +68,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
cache_lens, seq_lens, offset, alibi_slopes, stream);
history_lens, cum_seq_lens_q, alibi_slopes, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
......
......@@ -4,53 +4,53 @@
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionPrefillInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionPrefillInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t cache_lens_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t offset_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, const void *cache_lens, const void *seq_lens, \
const void *offset, \
const void *alibi_slopes, \
void *stream) const; \
}; \
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionPrefillInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionPrefillInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t history_lens_desc, \
infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, \
const void *history_lens, \
const void *cum_seq_lens_q, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_PREFILL_H
import sys
import os
import torch
import infinicore
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
TensorInitializer,
)
# Test case matrix; each tuple is:
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
_TEST_CASES_DATA = [
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
]
# Per-dtype numeric tolerances used when comparing infinicore output against
# the PyTorch reference; lower-precision dtypes get looser bounds.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},  # tuned tolerance for float32
    infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
}
# Every test case is generated once per dtype in this list.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
class SimpleCacheManager:
    """Toy block allocator for a paged KV cache.

    Hands out physical block ids from a free list and remembers, per
    request, which blocks it owns and how many tokens it has stored.
    """

    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        """Extend `request_id` by `num_new_tokens` tokens.

        Grabs additional blocks from the front of the free list whenever the
        new total length spills past the capacity of the blocks held so far.

        Returns:
            tuple: (list of physical block ids, new total token length).
        """
        blocks = self.request_to_blocks.setdefault(request_id, [])
        new_len = self.request_to_len.get(request_id, 0) + num_new_tokens
        required = -(-new_len // self.block_size)  # ceil(new_len / block_size)
        while len(blocks) < required:
            blocks.append(self.free_blocks.pop(0))
        self.request_to_len[request_id] = new_len
        return blocks, new_len
def parse_test_cases():
    """Build multi-round PagedAttentionPrefill test cases.

    For each configuration in _TEST_CASES_DATA this simulates `num_rounds`
    chunked-prefill steps: every round appends a random number of new query
    tokens per sequence, writes the matching K/V entries into a persistent
    paged cache (blocks allocated via SimpleCacheManager), and emits one
    TestCase per dtype carrying snapshots (clones) of the packed queries,
    the paged K/V caches, the padded block tables, the per-sequence history
    lengths, and the cumulative query offsets.

    Returns:
        list[TestCase]: all generated cases across configs/rounds/dtypes.
    """
    test_cases = []
    for (
        num_seqs,
        num_heads,
        num_kv_heads,
        head_size,
        block_size,
        max_step_len,
        num_rounds,
    ) in _TEST_CASES_DATA:
        scale = head_size**-0.5
        num_blocks = 8192  # large enough that the allocator never runs dry
        manager = SimpleCacheManager(num_blocks, block_size)
        current_history_lens = torch.zeros(num_seqs, dtype=torch.int64)
        # Paged caches persist across rounds:
        # (num_blocks, num_kv_heads, block_size, head_size)
        persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
        persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
        for r in range(num_rounds):
            # Random per-sequence query lengths for this round (packed layout).
            q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
            total_q_tokens = q_lens.sum().item()
            cu_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
            cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
            query_base = torch.randn((total_q_tokens, num_heads, head_size))
            round_block_tables_list = []
            for i in range(num_seqs):
                # Grow the sequence's block table, then write this round's
                # new K/V tokens at their logical positions after history.
                p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
                round_block_tables_list.append(p_blocks)
                h_len = current_history_lens[i].item()
                for t in range(q_lens[i].item()):
                    logical_pos = h_len + t
                    b_id = p_blocks[logical_pos // block_size]
                    off = logical_pos % block_size
                    persistent_k[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
                    persistent_v[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
            # Pad every block table to the round's maximum length; block 0 is
            # used as padding (unused entries are bounded by the lengths).
            max_blks = max(len(tbl) for tbl in round_block_tables_list)
            padded_tables = torch.tensor(
                [tbl + [0] * (max_blks - len(tbl)) for tbl in round_block_tables_list]
            )
            for dtype in _TENSOR_DTYPES:
                tolerance = _TOLERANCE_MAP.get(dtype)
                test_cases.append(
                    TestCase(
                        inputs=[
                            # NOTE(review): from_tensor receives a shape plus a
                            # MANUAL set_tensor clone; presumably the framework
                            # materializes the clone at `dtype` — confirm.
                            TensorSpec.from_tensor(
                                query_base.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=query_base.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                persistent_k.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=persistent_k.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                persistent_v.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=persistent_v.clone(),
                                dtype=dtype,
                            ),
                            TensorSpec.from_tensor(
                                padded_tables.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=padded_tables.clone(),
                                dtype=infinicore.int64,
                            ),
                            TensorSpec.from_tensor(
                                current_history_lens.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=current_history_lens.clone(),
                                dtype=infinicore.int64,
                            ),
                            TensorSpec.from_tensor(
                                cu_seqlens_q.shape,
                                init_mode=TensorInitializer.MANUAL,
                                set_tensor=cu_seqlens_q.clone(),
                                dtype=infinicore.int64,
                            ),
                        ],
                        kwargs={"scale": scale},
                        tolerance=tolerance,
                        description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}",
                    )
                )
            # Advance history so the next round appends after this round's tokens.
            current_history_lens += q_lens
    return test_cases
def ref_paged_attention_multi_turn(
    query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
):
    """PyTorch reference for paged-attention prefill over packed queries.

    For every sequence, gathers its full (history + new) K/V rows from the
    paged caches using the block table, applies a causal mask offset by the
    history length, and writes the attention result back at the sequence's
    packed positions in the output.
    """
    result = torch.zeros_like(query)
    blk = k_cache.shape[2]
    bounds = cu_seqlens_q.tolist()
    for seq_idx, past in enumerate(history_lens.tolist()):
        lo, hi = bounds[seq_idx], bounds[seq_idx + 1]
        n_new = hi - lo
        ctx_len = past + n_new
        table = block_tables[seq_idx]
        # Gather one (kv_heads, head_size) row per logical position.
        gathered_k = [
            k_cache[table[pos // blk].item(), :, pos % blk, :] for pos in range(ctx_len)
        ]
        gathered_v = [
            v_cache[table[pos // blk].item(), :, pos % blk, :] for pos in range(ctx_len)
        ]
        K = torch.stack(gathered_k, dim=0)
        V = torch.stack(gathered_v, dim=0)
        logits = torch.einsum("qhd,khd->hqk", query[lo:hi].float(), K.float()) * scale
        # Query t may attend to keys [0, past + t]; ban everything beyond.
        banned = torch.ones(n_new, ctx_len, device=query.device).triu(diagonal=past + 1).bool()
        causal = torch.zeros(n_new, ctx_len, device=query.device)
        causal[banned] = float("-inf")
        weights = torch.softmax(logits + causal.unsqueeze(0), dim=-1).to(query.dtype)
        result[lo:hi] = torch.einsum("hqk,khd->qhd", weights, V)
    return result
class OpTest(BaseOperatorTest):
    """Wires the PagedAttentionPrefill operator into the generic test framework."""

    def __init__(self):
        super().__init__("PagedAttentionPrefill")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(
        self, query, k_cache, v_cache, block_tables,
        history_lens, cu_seqlens_q, scale=1.0,
    ):
        # Golden reference computed with plain PyTorch ops.
        return ref_paged_attention_multi_turn(
            query, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, scale
        )

    def infinicore_operator(
        self, query, k_cache, v_cache, block_tables,
        history_lens, cu_seqlens_q, scale=1.0,
    ):
        # Device implementation under test; synchronize before returning so
        # the framework compares completed results.
        result = infinicore.paged_attention_prefill(
            query,
            k_cache,
            v_cache,
            block_tables,
            history_lens,
            cu_seqlens_q,
            alibi_slopes=None,
            scale=scale,
        )
        infinicore.sync_stream()
        return result
def main():
    """Entry point: run the PagedAttentionPrefill suite and exit with its status."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()
if __name__ == "__main__":
    main()
......@@ -1115,7 +1115,6 @@ def paged_attention_prefill_(lib):
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
]
......@@ -1139,7 +1138,6 @@ def paged_attention_prefill_(lib):
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
......
......@@ -74,14 +74,15 @@ class SimpleCacheManager:
def ref_paged_attention_multi_turn(
query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
):
block_size = k_cache.shape[2]
outputs = torch.zeros_like(query_new)
for i in range(len(offset) - 1):
total_len = seq_lens[i].item()
num_new = new_lens[i].item()
history_len = total_len - num_new
num_seqs = len(cum_seq_lens_q) - 1
for i in range(num_seqs):
num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
cache_len = seq_lens[i].item()
total_len = seq_lens[i].item() + num_new
table = block_tables[i]
keys_all, values_all = [], []
......@@ -93,19 +94,19 @@ def ref_paged_attention_multi_turn(
K = torch.stack(keys_all, dim=0)
V = torch.stack(values_all, dim=0)
Q = query_new[offset[i] : offset[i] + num_new, :, :]
Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]
scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
for q_idx in range(num_new):
mask[q_idx, : history_len + q_idx + 1] = 0.0
mask[q_idx, : cache_len + q_idx + 1] = 0.0
scores = scores + mask.unsqueeze(0)
attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, V)
outputs[offset[i] : offset[i] + num_new, :, :] = out
outputs[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :] = out
return outputs
......@@ -147,43 +148,43 @@ def test(
# Multi-turn testing loop
for r in range(num_rounds):
# Prepare dynamic inputs for this round
seq_lens_cpu = torch.randint(
query_lens_cpu = torch.randint(
1, max_step_len + 1, (num_seqs,), dtype=torch.int64
)
q_total_tokens = seq_lens_cpu.sum().item()
q_total_tokens = query_lens_cpu.sum().item()
q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)
cache_lens_list = []
seq_lens_list = []
all_block_tables = []
offset_list = []
cur_offset = 0
cum_seq_lens_q_list = []
cum_q_lens = 0
for i in range(num_seqs):
offset_list.append(cur_offset)
cum_seq_lens_q_list.append(cum_q_lens)
cur_new_len = seq_lens_cpu[i].item()
table, cache_len = manager.allocate_slots(i, cur_new_len)
cache_lens_list.append(cache_len)
cur_q_len = query_lens_cpu[i].item()
table, total_len = manager.allocate_slots(i, cur_q_len)
cur_seq_lens = total_len - cur_q_len
seq_lens_list.append(cur_seq_lens)
all_block_tables.append(table)
# Simulated KV insertion
k_new = torch.randn(cur_new_len, num_kv_heads, head_size)
v_new = torch.randn(cur_new_len, num_kv_heads, head_size)
q_val = torch.randn(cur_new_len, num_heads, head_size)
q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val
k_new = torch.randn(cur_q_len, num_kv_heads, head_size)
v_new = torch.randn(cur_q_len, num_kv_heads, head_size)
q_val = torch.randn(cur_q_len, num_heads, head_size)
q_packed_tensors[cum_q_lens : cum_q_lens + cur_q_len] = q_val
cur_offset = cur_offset + cur_new_len
cum_q_lens = cum_q_lens + cur_q_len
history_len = cache_len - cur_new_len
for t in range(cur_new_len):
logical_pos = history_len + t
for t in range(cur_q_len):
logical_pos = cur_seq_lens + t
b_id = table[logical_pos // block_size]
off = logical_pos % block_size
k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]
offset_list.append(cur_offset)
cum_seq_lens_q_list.append(cum_q_lens)
k_cache.actual_tensor().copy_(k_cache._torch_tensor)
v_cache.actual_tensor().copy_(v_cache._torch_tensor)
......@@ -193,13 +194,14 @@ def test(
out = TestTensor.from_torch(q_packed_tensors, dtype, device)
out.actual_tensor().zero_()
cache_lens = TestTensor.from_torch(
torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device
seq_lens = TestTensor.from_torch(
torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device
)
seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device)
offset = TestTensor.from_torch(
torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device
cum_seq_lens_q = TestTensor.from_torch(
torch.tensor(cum_seq_lens_q_list, dtype=torch.int64),
InfiniDtype.I64,
device,
)
max_blocks = max(len(t) for t in all_block_tables)
......@@ -215,9 +217,8 @@ def test(
k_cache.torch_tensor(),
v_cache.torch_tensor(),
block_tables.torch_tensor(),
cache_lens.torch_tensor(),
seq_lens.torch_tensor(),
offset.torch_tensor(),
cum_seq_lens_q.torch_tensor(),
scale,
)
......@@ -234,10 +235,9 @@ def test(
k_cache.descriptor,
v_cache.descriptor,
block_tables.descriptor,
cache_lens.descriptor,
seq_lens.descriptor,
offset.descriptor,
None, # alibi_slopes_desc
cum_seq_lens_q.descriptor,
None,
scale,
)
)
......@@ -261,9 +261,8 @@ def test(
k_cache.data(),
v_cache.data(),
block_tables.data(),
cache_lens.data(),
seq_lens.data(),
offset.data(),
cum_seq_lens_q.data(),
None,
None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment