Unverified Commit caa61e9e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #868 from InfiniTensor/issue/847

Issue/847  paged attention prefill一段式接口
parents 31c0af3f 99b940b2
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "../tensor.hpp" #include "../tensor.hpp"
#include <optional>
#include <type_traits> #include <type_traits>
namespace infinicore { namespace infinicore {
...@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) { ...@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
} }
} }
// Specialization for optional
template <typename T>
inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
hash_combine(seed, opt.has_value());
if (opt) {
hash_combine(seed, *opt);
}
}
// Specialization for std::string // Specialization for std::string
inline void hash_combine(size_t &seed, const std::string &str) { inline void hash_combine(size_t &seed, const std::string &str) {
hash_combine(seed, std::hash<std::string>{}(str)); hash_combine(seed, std::hash<std::string>{}(str));
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ops/matmul.hpp" #include "ops/matmul.hpp"
#include "ops/ones.hpp" #include "ops/ones.hpp"
#include "ops/paged_attention.hpp" #include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp" #include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp" #include "ops/random_sample.hpp"
#include "ops/rearrange.hpp" #include "ops/rearrange.hpp"
......
...@@ -9,10 +9,10 @@ namespace infinicore::op { ...@@ -9,10 +9,10 @@ namespace infinicore::op {
class PagedAttention { class PagedAttention {
public: public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float); using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float); static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher(); static common::OpDispatcher<schema> &dispatcher();
}; };
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale); Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale); void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op } // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
class PagedAttentionPrefill {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
...@@ -45,6 +45,7 @@ from infinicore.ops.matmul import matmul ...@@ -45,6 +45,7 @@ from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention from infinicore.ops.paged_attention import paged_attention
from infinicore.ops.paged_attention_prefill import paged_attention_prefill
from infinicore.ops.paged_caching import paged_caching from infinicore.ops.paged_caching import paged_caching
from infinicore.ops.rearrange import rearrange from infinicore.ops.rearrange import rearrange
from infinicore.ops.squeeze import squeeze from infinicore.ops.squeeze import squeeze
...@@ -119,6 +120,7 @@ __all__ = [ ...@@ -119,6 +120,7 @@ __all__ = [
"from_torch", "from_torch",
"paged_caching", "paged_caching",
"paged_attention", "paged_attention",
"paged_attention_prefill",
"ones", "ones",
"strided_empty", "strided_empty",
"strided_from_blob", "strided_from_blob",
......
...@@ -7,7 +7,7 @@ def paged_attention( ...@@ -7,7 +7,7 @@ def paged_attention(
k_cache: Tensor, k_cache: Tensor,
v_cache: Tensor, v_cache: Tensor,
block_tables: Tensor, block_tables: Tensor,
seq_lens: Tensor, cache_lens: Tensor,
alibi_slopes: Tensor | None = None, alibi_slopes: Tensor | None = None,
scale: float = 1.0, scale: float = 1.0,
*, *,
...@@ -20,7 +20,7 @@ def paged_attention( ...@@ -20,7 +20,7 @@ def paged_attention(
k_cache._underlying, k_cache._underlying,
v_cache._underlying, v_cache._underlying,
block_tables._underlying, block_tables._underlying,
seq_lens._underlying, cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None, alibi_slopes._underlying if alibi_slopes is not None else None,
scale, scale,
) )
...@@ -32,7 +32,7 @@ def paged_attention( ...@@ -32,7 +32,7 @@ def paged_attention(
k_cache._underlying, k_cache._underlying,
v_cache._underlying, v_cache._underlying,
block_tables._underlying, block_tables._underlying,
seq_lens._underlying, cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None, alibi_slopes._underlying if alibi_slopes is not None else None,
scale, scale,
) )
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def paged_attention_prefill(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
cache_lens: Tensor,
seq_lens: Tensor,
seq_offsets: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
if out is None:
return Tensor(
_infinicore.paged_attention_prefill(
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
)
_infinicore.paged_attention_prefill_(
out._underlying,
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
return out
...@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() { ...@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
return dispatcher_; return dispatcher_;
}; };
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) { void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
infinicore::context::setDevice(out->device()); infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
} }
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) { Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
return out; return out;
} }
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) { void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches( ...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
} }
}); });
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) { void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens); size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
auto device = context::getDevice(); auto device = context::getDevice();
auto &cache = caches.getCache(device); auto &cache = caches.getCache(device);
...@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc ...@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor( INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc, context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(), out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale)); scale));
cache.put(seed, desc); cache.put(seed, desc);
...@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc ...@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
INFINICORE_CHECK_ERROR(infiniopPagedAttention( INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size, desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(), out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream())); context::getStream()));
} }
......
#include "infinicore/ops/paged_attention_prefill.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::dispatcher() {
static common::OpDispatcher<PagedAttentionPrefill::schema> dispatcher_;
return dispatcher_;
};
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
return out;
}
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention_prefill.hpp"
#include <infiniop.h>
namespace infinicore::op::paged_attention_prefill_impl::infiniop {
thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t> caches(
100, // capacity
[](infiniopPagedAttentionPrefillDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionPrefillDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(),
cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), seq_lens->data(), seq_offsets->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
static bool registered = []() {
PagedAttentionPrefill::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::paged_attention_prefill_impl::infiniop
...@@ -8,21 +8,21 @@ namespace py = pybind11; ...@@ -8,21 +8,21 @@ namespace py = pybind11;
namespace infinicore::ops { namespace infinicore::ops {
Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) { Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt; std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) { if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>(); alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
} }
return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale); return op::paged_attention(q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
} }
void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) { void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt; std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) { if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>(); alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
} }
op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale); op::paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
} }
inline void bind_paged_attention(py::module &m) { inline void bind_paged_attention(py::module &m) {
...@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) { ...@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"), py::arg("k_cache"),
py::arg("v_cache"), py::arg("v_cache"),
py::arg("block_tables"), py::arg("block_tables"),
py::arg("seq_lens"), py::arg("cache_lens"),
py::arg("alibi_slopes"), py::arg("alibi_slopes"),
py::arg("scale"), py::arg("scale"),
R"doc(Paged attention of query and key cache tensors.)doc"); R"doc(Paged attention of query and key cache tensors.)doc");
...@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) { ...@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"), py::arg("k_cache"),
py::arg("v_cache"), py::arg("v_cache"),
py::arg("block_tables"), py::arg("block_tables"),
py::arg("seq_lens"), py::arg("cache_lens"),
py::arg("alibi_slopes"), py::arg("alibi_slopes"),
py::arg("scale"), py::arg("scale"),
R"doc(In-place paged attention of query and key cache tensors.)doc"); R"doc(In-place paged attention of query and key cache tensors.)doc");
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
namespace op::paged_attention_prefill::cuda { namespace op::paged_attention_prefill::cuda {
// 辅助函数:二分查找确定当前 global_token_idx 属于哪个 sequence // 辅助函数:二分查找确定当前 global_token_idx 属于哪个 sequence
__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) { __device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
int low = 0, high = num_seqs - 1; size_t low = 0, high = num_seqs - 1;
while (low <= high) { while (low <= high) {
int mid = (low + high) >> 1; size_t mid = (low + high) >> 1;
if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) { if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
return mid; return mid;
} else if (token_idx < offset[mid]) { } else if (token_idx < offset[mid]) {
...@@ -32,22 +32,22 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -32,22 +32,22 @@ __global__ void pagedAttentionPrefillKernel(
const size_t num_seqs) { const size_t num_seqs) {
// --- 使用 2D Grid 坐标 --- // --- 使用 2D Grid 坐标 ---
const int global_token_idx = blockIdx.x; // 展平后的全局 token 索引 const size_t global_token_idx = blockIdx.x; // 展平后的全局 token 索引
const int head_idx = blockIdx.y; // Head 索引 const size_t head_idx = blockIdx.y; // Head 索引
const int dim_idx = threadIdx.x; // Head 内部维度 const size_t dim_idx = threadIdx.x; // Head 内部维度
if (dim_idx >= head_size) { if (dim_idx >= head_size) {
return; return;
} }
// --- 通过二分查找 offset 找到所属的 seq_idx --- // --- 通过二分查找 offset 找到所属的 seq_idx ---
int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs); size_t seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
// --- 获取该 Sequence 本次 Prefill 的长度 // --- 获取该 Sequence 本次 Prefill 的长度
const int64_t cur_new_len = seq_lens_[seq_idx]; const int64_t cur_new_len = seq_lens_[seq_idx];
// --- 该 token 在当前序列中的相对位置 // --- 该 token 在当前序列中的相对位置
int q_token_idx = global_token_idx - offset_[seq_idx]; size_t q_token_idx = global_token_idx - offset_[seq_idx];
const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size; const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size; Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
...@@ -65,14 +65,14 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -65,14 +65,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 1: 计算 Score 并找最大值 // Pass 1: 计算 Score 并找最大值
Tcompute max_score = -FLT_MAX; Tcompute max_score = -FLT_MAX;
for (int t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const int64_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
for (int d = 0; d < head_size; ++d) { for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]); score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
} }
score *= static_cast<Tcompute>(scale); score *= static_cast<Tcompute>(scale);
...@@ -86,14 +86,14 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -86,14 +86,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 2: 计算 Sum of Exp // Pass 2: 计算 Sum of Exp
Tcompute sum_exp = 0.0f; Tcompute sum_exp = 0.0f;
for (int t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const int64_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
for (int d = 0; d < head_size; ++d) { for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]); score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
} }
score *= static_cast<Tcompute>(scale); score *= static_cast<Tcompute>(scale);
...@@ -106,14 +106,14 @@ __global__ void pagedAttentionPrefillKernel( ...@@ -106,14 +106,14 @@ __global__ void pagedAttentionPrefillKernel(
// Pass 3: 加权求和得到输出 // Pass 3: 加权求和得到输出
Tcompute acc = 0.0f; Tcompute acc = 0.0f;
Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f); Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
for (int t = 0; t <= causal_limit; ++t) { for (size_t t = 0; t <= causal_limit; ++t) {
const int64_t b_idx = t / block_size; const int64_t b_idx = t / block_size;
const int64_t t_off = t % block_size; const int64_t t_off = t % block_size;
const int64_t physical_block_id = block_table[b_idx]; const int64_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
Tcompute score = 0.0f; Tcompute score = 0.0f;
for (int d = 0; d < head_size; ++d) { for (size_t d = 0; d < head_size; ++d) {
score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]); score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
} }
score *= static_cast<Tcompute>(scale); score *= static_cast<Tcompute>(scale);
......
...@@ -62,7 +62,7 @@ def parse_test_cases(): ...@@ -62,7 +62,7 @@ def parse_test_cases():
max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing
seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64) cache_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
block_tables = torch.arange( block_tables = torch.arange(
0, num_seqs * max_blocks_per_seq, dtype=torch.int64 0, num_seqs * max_blocks_per_seq, dtype=torch.int64
...@@ -75,7 +75,7 @@ def parse_test_cases(): ...@@ -75,7 +75,7 @@ def parse_test_cases():
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
block_tables_shape = block_tables.shape block_tables_shape = block_tables.shape
seq_lens_shape = seq_lens_torch.shape cache_lens_shape = cache_lens_torch.shape
# Generate test cases for all data types # Generate test cases for all data types
for dtype in _TENSOR_DTYPES: for dtype in _TENSOR_DTYPES:
...@@ -91,10 +91,10 @@ def parse_test_cases(): ...@@ -91,10 +91,10 @@ def parse_test_cases():
set_tensor=block_tables, set_tensor=block_tables,
dtype=infinicore.int64, dtype=infinicore.int64,
) )
seq_lens_spec = TensorSpec.from_tensor( cache_lens_spec = TensorSpec.from_tensor(
seq_lens_shape, cache_lens_shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=seq_lens_torch, set_tensor=cache_lens_torch,
dtype=infinicore.int64, dtype=infinicore.int64,
) )
...@@ -108,7 +108,7 @@ def parse_test_cases(): ...@@ -108,7 +108,7 @@ def parse_test_cases():
k_cache_spec, k_cache_spec,
v_cache_spec, v_cache_spec,
block_tables_spec, block_tables_spec,
seq_lens_spec, cache_lens_spec,
], ],
kwargs={"alibi_slopes": None, "scale": scale}, kwargs={"alibi_slopes": None, "scale": scale},
output_spec=None, output_spec=None,
...@@ -132,7 +132,7 @@ def ref_masked_attention(query, key, value, scale, attn_mask=None): ...@@ -132,7 +132,7 @@ def ref_masked_attention(query, key, value, scale, attn_mask=None):
def ref_single_query_cached_kv_attention( def ref_single_query_cached_kv_attention(
query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale query, key_cache, value_cache, block_tables, cache_lens, alibi_slopes, scale
): ):
# Reference implementation for paged attention, iterating through each sequence. # Reference implementation for paged attention, iterating through each sequence.
output = torch.empty_like(query) output = torch.empty_like(query)
...@@ -143,7 +143,7 @@ def ref_single_query_cached_kv_attention( ...@@ -143,7 +143,7 @@ def ref_single_query_cached_kv_attention(
for i in range(num_seqs): for i in range(num_seqs):
q = query[i].unsqueeze(0) q = query[i].unsqueeze(0)
seq_len = seq_lens[i].item() seq_len = cache_lens[i].item()
block_table = block_tables[i] block_table = block_tables[i]
keys_lst, values_lst = [], [] keys_lst, values_lst = [], []
......
...@@ -55,7 +55,7 @@ target("infiniop-nvidia") ...@@ -55,7 +55,7 @@ target("infiniop-nvidia")
end end
end end
add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations") add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")
local arch_opt = get_config("cuda_arch") local arch_opt = get_config("cuda_arch")
if arch_opt and type(arch_opt) == "string" then if arch_opt and type(arch_opt) == "string" then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment