Commit 1c18c046 authored by PanZezhong, committed by wooway777

issue/979 optimize paged attention

parent 97eced0e
@@ -13,92 +13,171 @@ class PagedAttentionInfo {
PagedAttentionInfo() = default;
public:
// --- Data Types and Scale ---
infiniDtype_t dtype;
infiniDtype_t index_dtype;
float scale;
// --- Shape Dimensions ---
size_t num_seqs;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
size_t page_block_size;
size_t max_num_blocks_per_seq;
// --- Strides for Memory Layout ---
ptrdiff_t q_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t k_batch_stride;
ptrdiff_t k_row_stride;
ptrdiff_t k_head_stride;
ptrdiff_t v_batch_stride;
ptrdiff_t v_row_stride;
ptrdiff_t v_head_stride;
ptrdiff_t o_stride;
ptrdiff_t block_table_batch_stride;
ptrdiff_t cache_lens_stride;
static utils::Result<PagedAttentionInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) {
if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->dtype() != INFINI_DTYPE_I64) {
CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
const auto block_tables_dt = block_tables_desc->dtype();
const auto cache_lens_dt = cache_lens_desc->dtype();
const bool debug_dtype = (std::getenv("INFINIOP_FLASH_DEBUG_DTYPE") != nullptr);
const bool block_tables_ok = (block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32);
const bool cache_lens_ok = (cache_lens_dt == INFINI_DTYPE_I64) || (cache_lens_dt == INFINI_DTYPE_I32) || (cache_lens_dt == INFINI_DTYPE_U32);
if (!(block_tables_ok && cache_lens_ok)) {
if (debug_dtype) {
std::fprintf(stderr,
"[flash_attention] Bad index dtype: block_tables=%d cache_lens=%d (expected I32/I64/U32)\n",
static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
if (block_tables_dt != cache_lens_dt) {
// Keep them consistent to simplify backend dispatch.
if (debug_dtype) {
std::fprintf(stderr,
"[flash_attention] Mismatched index dtype: block_tables=%d cache_lens=%d\n",
static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// --- Extract shape dimensions ---
CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (alibi_slopes_desc.value()->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
}
// Shapes
auto q_shape = q_desc->shape();
auto k_cache_shape = k_cache_desc->shape();
auto k_shape = k_cache_desc->shape();
const size_t num_seqs = q_shape[0];
const size_t num_heads = q_shape[1];
const size_t head_size = q_shape[2];
const size_t num_blocks = k_shape[0];
(void)num_blocks;
const size_t page_block_size = k_shape[2];
const size_t num_kv_heads = k_shape[1];
// if (page_block_size % 256 != 0) {
// printf("paged block size %zu\n", page_block_size);
// return INFINI_STATUS_BAD_TENSOR_SHAPE;
// }
if (head_size != 64 && head_size != 128) {
// First build only targets common FA2 head dims (expand later).
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (num_heads % num_kv_heads != 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[1] != k_shape[1] || v_cache_desc->shape()[2] != k_shape[2] || v_cache_desc->shape()[3] != k_shape[3]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t num_seqs = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got "
<< head_size << "." << std::endl;
if (cache_lens_desc->shape()[0] != num_seqs) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t num_kv_heads = k_cache_shape[1];
size_t block_size = v_cache_desc->shape()[2]; // Using the V cache's block size dimension is more reliable
size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
// Strides (in elements)
const ptrdiff_t q_stride = q_desc->stride(0);
const ptrdiff_t o_stride = out_desc->stride(0);
const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
// --- Calculate max_seq_len for shared memory allocation ---
// This is a safe upper bound.
// info.max_seq_len = info.max_num_blocks_per_seq * info.block_size;
// --- Extract strides for memory access ---
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);
const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
const ptrdiff_t cache_lens_stride = cache_lens_desc->stride(0);
return utils::Result<PagedAttentionInfo>(PagedAttentionInfo{
dtype,
block_tables_dt,
scale,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
page_block_size,
max_num_blocks_per_seq,
q_stride,
kv_block_stride,
kv_head_stride,
o_stride});
k_batch_stride,
k_row_stride,
k_head_stride,
v_batch_stride,
v_row_stride,
v_head_stride,
o_stride,
block_table_batch_stride,
cache_lens_stride,
});
}
};
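For orientation only, here is a minimal sketch (not part of this commit) of how a decode kernel could turn a token's logical position into a K-cache offset using the strides gathered above. The layout assumption ([num_blocks, kv_heads, page_block_size, head_dim], matching how num_kv_heads and page_block_size are read from k_shape), the int64 block table, and the helper name locate_k are all hypothetical.

#include <cstddef>
#include <cstdint>

// Hypothetical helper: element offset of the K vector for (sequence seq,
// kv head h, token position pos), given a row-major block table.
// Strides are in elements, as collected by PagedAttentionInfo::create().
inline ptrdiff_t locate_k(const int64_t *block_tables,
                          ptrdiff_t block_table_batch_stride,
                          size_t seq, size_t h, size_t pos,
                          size_t page_block_size,
                          ptrdiff_t k_batch_stride,
                          ptrdiff_t k_head_stride,
                          ptrdiff_t k_row_stride) {
    const size_t logical_block = pos / page_block_size; // which page
    const size_t in_block = pos % page_block_size;      // row inside the page
    const int64_t physical_block =
        block_tables[static_cast<ptrdiff_t>(seq) * block_table_batch_stride + logical_block];
    // head_dim is contiguous (stride(3) == 1 is checked above), so the
    // head_size elements follow this offset directly. For grouped-query
    // attention, h would be q_head / (num_heads / num_kv_heads).
    return static_cast<ptrdiff_t>(physical_block) * k_batch_stride
         + static_cast<ptrdiff_t>(h) * k_head_stride
         + static_cast<ptrdiff_t>(in_block) * k_row_stride;
}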
@@ -3,6 +3,7 @@
#include "../../../utils.h"
#include "../../tensor.h"
#include <cstring>
#include <iostream>
#include <optional>
#include <vector>
@@ -14,21 +15,30 @@ class PagedAttentionPrefillInfo {
public:
infiniDtype_t dtype;
infiniDtype_t index_dtype;
float scale;
size_t num_seqs;
size_t total_q_tokens;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
size_t page_block_size;
size_t max_num_blocks_per_seq;
size_t total_q_tokens;
size_t num_blocks;
ptrdiff_t q_stride;
ptrdiff_t q_head_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t k_batch_stride;
ptrdiff_t k_row_stride;
ptrdiff_t k_head_stride;
ptrdiff_t v_batch_stride;
ptrdiff_t v_row_stride;
ptrdiff_t v_head_stride;
ptrdiff_t o_stride;
ptrdiff_t o_head_stride;
ptrdiff_t block_table_batch_stride;
static utils::Result<PagedAttentionPrefillInfo> create(
infiniopTensorDescriptor_t out_desc,
@@ -36,89 +46,161 @@ public:
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
// q/out: [total_q, heads, head_dim]
if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// FA2 paged KV layout: [num_blocks, page_block_size, kv_heads, head_dim]
if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->ndim() != 2 || total_kv_lens_desc->ndim() != 1 || cum_seqlens_q_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
// Index dtypes: allow I32/I64/U32 (v0.4 roadmap allows internal conversion to I32).
const auto block_tables_dt = block_tables_desc->dtype();
if (!((block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
// (matches current paged_attention_prefill signature). We will convert to int32 internally later.
if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (alibi_slopes_desc.value()->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
}
auto k_shape = k_cache_desc->shape();
auto v_shape = v_cache_desc->shape();
auto block_tables_shape = block_tables_desc->shape();
auto seq_lens_shape = seq_lens_desc->shape();
auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
const auto q_shape = q_desc->shape();
const auto k_shape = k_cache_desc->shape();
const size_t total_q_tokens = q_shape[0];
const size_t num_heads = q_shape[1];
const size_t head_size = q_shape[2];
const size_t num_blocks = k_shape[0];
const size_t page_block_size = k_shape[2];
const size_t num_kv_heads = k_shape[1];
if (k_shape.size() != 4 || v_shape.size() != 4) {
if (head_size != 64 && head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (num_heads % num_kv_heads != 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_shape.size() != 2) {
// v_cache must match the inferred K layout.
const auto v_shape = v_cache_desc->shape();
if (v_shape[0] != num_blocks || v_shape[3] != head_size) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
if (v_shape[1] != num_kv_heads || v_shape[2] != page_block_size) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
return INFINI_STATUS_BAD_PARAM;
if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[3] != k_shape[3]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Q shape: [total_tokens, heads, dim]
auto q_shape = q_desc->shape();
if (q_shape.size() != 3) {
if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t total_q_tokens = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
if (head_size > 1024) {
const size_t num_seqs = total_kv_lens_desc->shape()[0];
if (cum_seqlens_q_desc->shape()[0] != num_seqs + 1) {
return INFINI_STATUS_BAD_PARAM;
}
size_t num_seqs = seq_lens_shape[0];
size_t num_kv_heads = k_shape[1];
size_t block_size = k_shape[2];
size_t max_num_blocks_per_seq = block_tables_shape[1];
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t q_head_stride = q_desc->stride(1);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);
const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
// Strides (in elements)
const ptrdiff_t q_stride = q_desc->stride(0);
const ptrdiff_t q_head_stride = q_desc->stride(1);
const ptrdiff_t o_stride = out_desc->stride(0);
const ptrdiff_t o_head_stride = out_desc->stride(1);
const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
if (const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_INFO")) {
static bool printed = false;
if (!printed && std::strcmp(dbg, "1") == 0) {
const auto bt_shape = block_tables_desc->shape();
std::fprintf(stderr,
"[infiniop][flash_attention_prefill][info] k_shape=[%zu,%zu,%zu,%zu] k_strides=[%td,%td,%td,%td] (row_stride=%td head_stride=%td)\n",
static_cast<size_t>(k_shape[0]), static_cast<size_t>(k_shape[1]),
static_cast<size_t>(k_shape[2]), static_cast<size_t>(k_shape[3]),
k_cache_desc->stride(0), k_cache_desc->stride(1), k_cache_desc->stride(2), k_cache_desc->stride(3),
k_row_stride, k_head_stride);
std::fprintf(stderr,
"[infiniop][flash_attention_prefill][info] block_tables shape=[%zu,%zu] strides=[%td,%td]\n",
static_cast<size_t>(bt_shape[0]), static_cast<size_t>(bt_shape[1]),
block_tables_desc->stride(0), block_tables_desc->stride(1));
printed = true;
}
}
return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
dtype,
block_tables_dt,
scale,
num_seqs,
total_q_tokens,
num_heads,
num_kv_heads,
head_size,
block_size,
page_block_size,
max_num_blocks_per_seq,
total_q_tokens,
num_blocks,
q_stride,
q_head_stride,
kv_block_stride,
kv_head_stride,
o_stride});
k_batch_stride,
k_row_stride,
k_head_stride,
v_batch_stride,
v_row_stride,
v_head_stride,
o_stride,
o_head_stride,
block_table_batch_stride,
});
}
};
} // namespace op::paged_attention_prefill
#endif
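As a side note, a minimal sketch (an assumption, not part of the library API) of the host-side bookkeeping implied by the checks above: cum_seqlens_q is a prefix sum over per-sequence query lengths, so it has num_seqs + 1 entries and its last entry equals total_q_tokens, which is exactly the shape relation create() validates.

#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical construction of cum_seqlens_q from per-sequence query lengths.
// cum[0] == 0 and cum[num_seqs] == total_q_tokens.
std::vector<int64_t> build_cum_seqlens_q(const std::vector<int64_t> &q_lens) {
    std::vector<int64_t> cum(q_lens.size() + 1, 0);
    std::partial_sum(q_lens.begin(), q_lens.end(), cum.begin() + 1);
    return cum;
}

// Queries for sequence s then occupy rows [cum[s], cum[s + 1]) of the
// flattened [total_q_tokens, num_heads, head_size] Q tensor.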
@@ -100,13 +100,12 @@ _TEST_CASES_ = [
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Global flags for controlling test behavior
@@ -32,10 +32,9 @@ _TEST_CASES = [
(16, 128, 128, 128, 8, 16, 4),
]
_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
_TOLERANCE_MAP = {
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}
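For completeness, the atol/rtol pairs above follow the usual elementwise closeness convention, |a - b| <= atol + rtol * |b|. A minimal C++ sketch of that check (an assumption about the comparison rule; the tests themselves rely on the framework's own comparison helpers):

#include <cmath>
#include <cstddef>

// Hypothetical allclose-style check matching the atol/rtol convention above.
bool allclose(const float *a, const float *b, size_t n, float atol, float rtol) {
    for (size_t i = 0; i < n; ++i) {
        if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i])) {
            return false;
        }
    }
    return true;
}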