Commit 298feac2 (unverified)
Authored Dec 29, 2025 by PanZezhong1725; committed by GitHub on Dec 29, 2025

Merge pull request #836 from InfiniTensor/issue/834
issue/834: add paged attention for nvidia gpu

Parents: 27777ee1, 17299923
Showing 18 changed files with 1836 additions and 0 deletions (+1836, -0)
include/infiniop.h  (+2, -0)
include/infiniop/ops/paged_attention.h  (+93, -0)
include/infiniop/ops/paged_caching.h  (+77, -0)
src/infiniop/ops/paged_attention/cuda/kernel.cuh  (+149, -0)
src/infiniop/ops/paged_attention/info.h  (+109, -0)
src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu  (+135, -0)
src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh  (+8, -0)
src/infiniop/ops/paged_attention/operator.cc  (+105, -0)
src/infiniop/ops/paged_attention/paged_attention.h  (+53, -0)
src/infiniop/ops/paged_caching/cuda/kernel.cuh  (+88, -0)
src/infiniop/ops/paged_caching/info.h  (+82, -0)
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu  (+163, -0)
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh  (+8, -0)
src/infiniop/ops/paged_caching/operator.cc  (+100, -0)
src/infiniop/ops/paged_caching/paged_caching.h  (+50, -0)
test/infiniop/libinfiniop/op_register.py  (+84, -0)
test/infiniop/paged_attention.py  (+279, -0)
test/infiniop/paged_caching.py  (+251, -0)

include/infiniop.h
@@ -31,5 +31,7 @@
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/tensor_descriptor.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_caching.h"
#endif // __INFINIOP_API_H__

include/infiniop/ops/paged_attention.h (new file, mode 100644)

#ifndef __INFINIOP_PAGED_ATTENTION_API_H__
#define __INFINIOP_PAGED_ATTENTION_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Attention descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention v1 operation.
 *
 * @param handle The library context handle.
 * @param desc_ptr Pointer to the created descriptor.
 * @param out_desc [Output] Shape: (num_seqs, num_heads, head_size).
 *        The output tensor for the attention mechanism.
 * @param q_desc [Input] Shape: (num_seqs, num_heads, head_size).
 *        The query tensor.
 * @param k_cache_desc [Input] Shape: (num_blocks, num_kv_heads, block_size, head_size).
 *        Paged key cache storing keys for all sequences.
 * @param v_cache_desc [Input] Shape: (num_blocks, num_kv_heads, block_size, head_size).
 *        Paged value cache storing values for all sequences.
 * @param block_tables_desc [Input] Shape: (num_seqs, max_num_blocks_per_seq).
 *        Maps each sequence to its physical block indices in the cache.
 *        Expected DType: int64_t (I64).
 * @param seq_lens_desc [Input] Shape: (num_seqs,).
 *        The current logical length of each sequence.
 *        Expected DType: int64_t (I64).
 * @param alibi_slopes_desc [Optional] Shape: (num_heads,).
 *        Slopes for ALiBi (Attention with Linear Biases). Can be NULL.
 * @param scale The attention scaling factor (typically 1/sqrt(head_size)).
 * @return infiniStatus_t Status code.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention operation.
 *
 * @param desc The Paged Attention descriptor.
 * @param size A pointer to store the required workspace size in bytes.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
    infiniopPagedAttentionDescriptor_t desc,
    size_t *size);

/**
 * @brief Executes the Paged Attention v1 operation.
 *
 * @param desc The Paged Attention descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data.
 * @param k_cache Pointer to the key cache data.
 * @param v_cache Pointer to the value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param seq_lens Pointer to the sequence lengths data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA stream for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttention(
    infiniopPagedAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
    infiniopPagedAttentionDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_API_H__
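
To make the call sequence concrete, here is a minimal host-side sketch of the lifecycle this header defines (create, query workspace, execute, destroy). Only the four functions declared above come from the header; the handle, tensor descriptors, device buffers, and the CHECK macro are assumed to exist and are purely illustrative.

// Minimal sketch, assuming handle, tensor descriptors and device buffers already exist.
// CHECK is a hypothetical status-checking macro, not part of this header.
infiniopPagedAttentionDescriptor_t pa_desc;
CHECK(infiniopCreatePagedAttentionDescriptor(
    handle, &pa_desc,
    out_desc, q_desc, k_cache_desc, v_cache_desc,
    block_tables_desc, seq_lens_desc,
    /*alibi_slopes_desc=*/NULL,
    /*scale=*/1.0f / sqrtf(128.0f)));

size_t ws_size = 0;
CHECK(infiniopGetPagedAttentionWorkspaceSize(pa_desc, &ws_size));
void *workspace = NULL; /* allocate ws_size bytes of device memory here */

CHECK(infiniopPagedAttention(pa_desc, workspace, ws_size,
                             d_out, d_q, d_k_cache, d_v_cache,
                             d_block_tables, d_seq_lens,
                             /*alibi_slopes=*/NULL, /*stream=*/NULL));

CHECK(infiniopDestroyPagedAttentionDescriptor(pa_desc));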

include/infiniop/ops/paged_caching.h (new file, mode 100644)

#ifndef __INFINIOP_PAGED_CACHING_API_H__
#define __INFINIOP_PAGED_CACHING_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Caching descriptor.
typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Caching operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * to copy key/value vectors into their respective cache pools.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param k_desc Descriptor for the source key tensor.
 * @param v_desc Descriptor for the source value tensor.
 * @param k_cache_desc Descriptor for the key cache pool tensor.
 * @param v_cache_desc Descriptor for the value cache pool tensor.
 * @param slot_mapping_desc Descriptor for the slot mapping tensor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
    infiniopHandle_t handle,
    infiniopPagedCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc);

/**
 * @brief Retrieves the workspace size required for the Paged Caching operation.
 *
 * @param desc The Paged Caching descriptor.
 * @param size A pointer to store the required workspace size in bytes (typically 0).
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
    infiniopPagedCachingDescriptor_t desc,
    size_t *size);

/**
 * @brief Executes the Paged Caching operation.
 *
 * @param desc The Paged Caching descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param k Pointer to the source key tensor data.
 * @param v Pointer to the source value tensor data.
 * @param k_cache Pointer to the key cache pool data.
 * @param v_cache Pointer to the value cache pool data.
 * @param slot_mapping Pointer to the slot mapping data.
 * @param stream The CUDA stream for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedCaching(
    infiniopPagedCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    const void *k,
    const void *v,
    void *k_cache,
    void *v_cache,
    const void *slot_mapping,
    void *stream);

/**
 * @brief Destroys a Paged Caching descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedCachingDescriptor(
    infiniopPagedCachingDescriptor_t desc);

#endif // __INFINIOP_PAGED_CACHING_API_H__
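
For orientation, a hedged sketch of how a caller might build the slot_mapping consumed by this operator. The decomposition slot = physical_block * block_size + offset matches what the caching kernel later in this commit decodes; the loop and variable names are illustrative and not part of the API.

// Illustrative only: one way to compute slot_mapping for the newest token of each
// sequence, assuming the slot convention decoded by the caching kernel
// (physical_block = slot / block_size, offset = slot % block_size).
for (int64_t s = 0; s < num_seqs; ++s) {
    int64_t pos = seq_lens[s] - 1;                  /* position being written */
    int64_t logical_block = pos / block_size;
    int64_t physical_block = block_tables[s * max_num_blocks_per_seq + logical_block];
    slot_mapping[s] = physical_block * block_size + pos % block_size;
}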

src/infiniop/ops/paged_attention/cuda/kernel.cuh (new file, mode 100644)

#ifndef __PAGED_ATTENTION_KERNEL_CUH__
#define __PAGED_ATTENTION_KERNEL_CUH__

// This kernel is refactored to be high-performance, adopting parallelism strategies
// from industry-standard implementations like vLLM. It fixes functional and performance
// issues in the original draft.

namespace op::paged_attention::cuda {

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int64_t *block_tables_,
    const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride) {
    //================================================================================
    // 1. Setup & Query Loading
    //================================================================================
    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;

    const int64_t seq_len = seq_lens_[seq_idx];
    if (seq_len == 0) {
        return;
    }

    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    extern __shared__ char shared_mem_char[];
    Tcompute *shared_mem = reinterpret_cast<Tcompute *>(shared_mem_char);
    Tcompute *q_shared = shared_mem;
    Tcompute *logits = shared_mem + HEAD_SIZE;

    for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
        q_shared[i] = static_cast<Tcompute>(q_ptr[i]);
    }
    __syncthreads();

    //================================================================================
    // 2. Compute QK Dot Product & Find Max Logit
    //================================================================================
    for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) {
        const int64_t block_idx = token_idx / block_size;
        const int64_t token_in_block_idx = token_idx % block_size;
        const int64_t physical_block_num = block_table[block_idx];

        const Tdata *k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride
                               + kv_head_idx * kv_head_stride
                               + token_in_block_idx * HEAD_SIZE;

        Tcompute qk = 0.0f;
#pragma unroll
        for (size_t i = 0; i < HEAD_SIZE / 8; ++i) {
            const size_t offset = i * 8;
            // Manually unrolled: accumulate 8 products per iteration.
            qk += q_shared[offset + 0] * static_cast<Tcompute>(k_vec_ptr[offset + 0]);
            qk += q_shared[offset + 1] * static_cast<Tcompute>(k_vec_ptr[offset + 1]);
            qk += q_shared[offset + 2] * static_cast<Tcompute>(k_vec_ptr[offset + 2]);
            qk += q_shared[offset + 3] * static_cast<Tcompute>(k_vec_ptr[offset + 3]);
            qk += q_shared[offset + 4] * static_cast<Tcompute>(k_vec_ptr[offset + 4]);
            qk += q_shared[offset + 5] * static_cast<Tcompute>(k_vec_ptr[offset + 5]);
            qk += q_shared[offset + 6] * static_cast<Tcompute>(k_vec_ptr[offset + 6]);
            qk += q_shared[offset + 7] * static_cast<Tcompute>(k_vec_ptr[offset + 7]);
        }

        qk *= scale;
        if (alibi_slope != 0.0f) {
            qk += alibi_slope * (token_idx - seq_len + 1);
        }
        logits[token_idx] = qk;
    }
    __syncthreads();

    __shared__ Tcompute global_qk_max;
    Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max<NUM_THREADS, Tcompute>(logits, seq_len);
    if (threadIdx.x == 0) {
        global_qk_max = global_qk_max_0;
    }
    __syncthreads();

    //================================================================================
    // 3. Compute Softmax
    //================================================================================
    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        Tcompute val = expf(logits[i] - global_qk_max); // subtract the global maximum for numerical stability
        logits[i] = val;
    }
    __syncthreads();

    __shared__ Tcompute inv_sum;
    Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum<NUM_THREADS, Tcompute, Tcompute>(logits, seq_len);
    if (threadIdx.x == 0) {
        inv_sum = 1.0f / (exp_sum_0 + 1e-6f);
    }
    __syncthreads();

    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        logits[i] *= inv_sum;
    }
    __syncthreads();

    //================================================================================
    // 4. Aggregate Values (V) weighted by probabilities
    //================================================================================
    for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) {
        Tcompute acc = 0.0f;
        for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
            const size_t block_idx = token_idx / block_size;
            const size_t token_in_block_idx = token_idx % block_size;
            const int64_t physical_block_num = block_table[block_idx];
            const Tcompute prob = logits[token_idx];

            const Tdata *v_vec_ptr = v_cache_ + physical_block_num * kv_block_stride
                                   + kv_head_idx * kv_head_stride
                                   + token_in_block_idx * HEAD_SIZE;
            const Tdata v_val = v_vec_ptr[h_dim];
            acc += prob * static_cast<Tcompute>(v_val);
        }
        out_ptr[h_dim] = static_cast<Tdata>(acc);
    }
}

} // namespace op::paged_attention::cuda

#endif // __PAGED_ATTENTION_KERNEL_CUH__
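
Restated as math (nothing beyond what the kernel above computes): for one query head with ALiBi slope $m$ (zero when disabled), sequence length $L$, and keys/values $k_j, v_j$ gathered through the block table,

\begin{aligned}
\ell_j &= \mathrm{scale}\cdot\langle q, k_j\rangle + m\,(j - L + 1), \qquad j = 0,\dots,L-1,\\
p_j   &= \frac{\exp(\ell_j - \max_i \ell_i)}{\sum_{i=0}^{L-1}\exp(\ell_i - \max_i \ell_i) + 10^{-6}},\\
\mathrm{out} &= \sum_{j=0}^{L-1} p_j\, v_j .
\end{aligned}

The $10^{-6}$ term mirrors the kernel's `1.0f / (exp_sum_0 + 1e-6f)` guard against a zero denominator.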

src/infiniop/ops/paged_attention/info.h (new file, mode 100644)

#ifndef __PAGED_ATTENTION_INFO_H__
#define __PAGED_ATTENTION_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <iostream>
#include <optional>
#include <vector>

namespace op::paged_attention {

class PagedAttentionInfo {
    PagedAttentionInfo() = default;

public:
    // --- Data Types and Scale ---
    infiniDtype_t dtype;
    float scale;

    // --- Shape Dimensions ---
    size_t num_seqs;
    size_t num_heads;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    size_t max_num_blocks_per_seq;

    // --- Strides for Memory Layout ---
    ptrdiff_t q_stride;
    ptrdiff_t kv_block_stride;
    ptrdiff_t kv_head_stride;
    ptrdiff_t o_stride;

    static utils::Result<PagedAttentionInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t seq_lens_desc,
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {

        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);

        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4
            || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (block_tables_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // --- Extract shape dimensions ---
        auto q_shape = q_desc->shape();
        auto k_cache_shape = k_cache_desc->shape();
        size_t num_seqs = q_shape[0];
        size_t num_heads = q_shape[1];
        size_t head_size = q_shape[2];
        if (head_size != 128) {
            // Report the offending value; only head_size = 128 is supported for now.
            std::cerr << "[Error] Now only supports head_size = 128, but got " << head_size << "." << std::endl;
            // Return a shape-related error code.
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t num_kv_heads = k_cache_shape[1];
        size_t block_size = v_cache_desc->shape()[2]; // the V cache's block_size dimension is the more reliable source
        size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];

        // --- Calculate max_seq_len for shared memory allocation ---
        // This is a safe upper bound.
        // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size;

        // --- Extract strides for memory access ---
        ptrdiff_t q_stride = q_desc->stride(0);
        ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
        ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
        ptrdiff_t o_stride = out_desc->stride(0);

        return utils::Result<PagedAttentionInfo>(PagedAttentionInfo{
            dtype, scale,
            num_seqs, num_heads, num_kv_heads, head_size, block_size, max_num_blocks_per_seq,
            q_stride, kv_block_stride, kv_head_stride, o_stride});
    }
};

} // namespace op::paged_attention

#endif // __PAGED_ATTENTION_INFO_H__

src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu (new file, mode 100644)

#include <cub/block/block_reduce.cuh>

#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_nvidia.cuh"

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedAttention(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
    const size_t num_kv_heads, const float scale,
    const size_t max_num_blocks_per_seq, const size_t block_size,
    const ptrdiff_t q_stride, const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride, const ptrdiff_t o_stride) {
    op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
        out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, block_size,
        q_stride, kv_block_stride, kv_head_stride, o_stride);
}

namespace op::paged_attention::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc,
                                           block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
    CHECK_RESULT(info);

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype,
    const void *block_tables, const void *seq_lens, const void *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t block_size,
    ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
    cudaStream_t stream) {
    dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
    dim3 block(NUM_THREADS);
    size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);

    if (dtype == INFINI_DTYPE_F16) {
        pagedAttention<half, float, HEAD_SIZE, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (half *)out, (const half *)q, (const half *)k_cache, (const half *)v_cache,
            (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, block_size,
            q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else if (dtype == INFINI_DTYPE_BF16) {
        pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
            (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, block_size,
            q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else if (dtype == INFINI_DTYPE_F32) {
        pagedAttention<float, float, HEAD_SIZE, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
            (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, block_size,
            q_stride, kv_block_stride, kv_head_stride, o_stride);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *seq_lens,
    const void *alibi_slopes, void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;

    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        if (_info.head_size == 128) {
            launchKernel<128, CUDA_BLOCK_SIZE_1024>(
                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.block_size,
                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, stream);
        }
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        if (_info.head_size == 128) {
            launchKernel<128, CUDA_BLOCK_SIZE_512>(
                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.block_size,
                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, stream);
        }
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        if (_info.head_size == 128) {
            launchKernel<128, CUDA_BLOCK_SIZE_4096>(
                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.block_size,
                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, stream);
        }
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::paged_attention::nvidia
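
One practical consequence of the launch above: launchKernel requests (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float) bytes of dynamic shared memory, since the whole logits vector lives in shared memory. A hedged sketch of how a caller could sanity-check that request against the device limit (a standard CUDA runtime query; this check itself is not part of the patch):

// Sketch only: verify the dynamic shared-memory request fits the device limit.
// Mirrors the size formula used in launchKernel above; variable names are illustrative.
int dev = 0, smem_per_block = 0;
cudaGetDevice(&dev);
cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

size_t requested = (head_size + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
if (requested > (size_t)smem_per_block) {
    /* The single-pass kernel cannot hold this many logits in shared memory;
       the maximum supported context length is bounded accordingly. */
}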

src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh (new file, mode 100644)

#ifndef __PAGED_ATTENTION_NVIDIA_H__
#define __PAGED_ATTENTION_NVIDIA_H__

#include "../paged_attention.h"

DESCRIPTOR(nvidia)

#endif // __PAGED_ATTENTION_NVIDIA_H__

src/infiniop/ops/paged_attention/operator.cc (new file, mode 100644)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention.h"

#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif

__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale) {
    infiniopTensorDescriptor_t alibi_opt = (alibi_slopes_desc == nullptr) ? nullptr : alibi_slopes_desc;

#define CREATE(CASE, NAMESPACE)                                                         \
    case CASE:                                                                          \
        return op::paged_attention::NAMESPACE::Descriptor::create(                     \
            handle,                                                                     \
            reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_opt, scale);

    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
    infiniopPagedAttentionDescriptor_t desc,
    size_t *size) {
#define GET(CASE, NAMESPACE)                                                                            \
    case CASE:                                                                                          \
        *size = reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopPagedAttention(
    infiniopPagedAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *alibi_slopes,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE)                                                               \
    case CASE:                                                                                   \
        return reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, out, q, k_cache, v_cache, block_tables,                  \
            seq_lens, alibi_slopes, stream);

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
    infiniopPagedAttentionDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                                      \
    case CASE:                                                                        \
        delete reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        DESTROY(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

src/infiniop/ops/paged_attention/paged_attention.h (new file, mode 100644)

#ifndef PAGED_ATTENTION_H
#define PAGED_ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, const void *seq_lens, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_H

src/infiniop/ops/paged_caching/cuda/kernel.cuh (new file, mode 100644)

#ifndef __PAGED_CACHING_KERNEL_CUH__
#define __PAGED_CACHING_KERNEL_CUH__

//================================================================================
// Paged Caching Operator CUDA Kernel
//
// This kernel implements the "paged_caching" operation, which copies Key and Value
// vectors from a contiguous source tensor into a paged, non-contiguous KV Cache.
//
// Design Principles:
// 1. Token-Centric Parallelism: A 2D grid of (num_kv_heads, num_tokens) is launched.
//    Each CUDA block is responsible for caching one head of one token.
// 2. Coalesced Memory Access: This grid strategy ensures that threads within a
//    block read a large, contiguous chunk of memory from the source tensors,
//    maximizing memory bandwidth utilization.
// 3. Vectorization: The copy can be vectorized to further enhance memory
//    throughput; the current implementation uses a safe element-wise copy (see section 2).
//================================================================================

namespace op::paged_caching::cuda {

template <typename Tdata,  // Data type of the tensors (e.g., half, __nv_bfloat16)
          int NUM_THREADS> // Number of threads per block, configured at launch time
__device__ void pagedCachingKernel(
    // ----- Output Tensors -----
    Tdata *k_cache_ptr,              // Destination K cache pool [num_blocks, nkvh, block_size, dh]
    Tdata *v_cache_ptr,              // Destination V cache pool [num_blocks, nkvh, block_size, dh]
    // ----- Input Tensors -----
    const Tdata *k_ptr,              // Source Keys, shape [ntok, nkvh, dh]
    const Tdata *v_ptr,              // Source Values, shape [ntok, nkvh, dh]
    const int64_t *slot_mapping_ptr, // Slot mapping, shape [ntok]
    // ----- Metadata -----
    const size_t head_size,          // Dimension of each head (dh)
    const size_t block_size,         // Number of tokens per block in the KV cache
    // ----- Stride Information -----
    const ptrdiff_t k_src_stride,         // Stride between tokens in the source K tensor
    const ptrdiff_t v_src_stride,         // Stride between tokens in the source V tensor
    const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
    const ptrdiff_t v_cache_block_stride  // Stride between blocks in the V cache pool
) {
    //================================================================================
    // 1. Identify Work Unit & Calculate Addresses
    //================================================================================
    // Each block processes one (token, kv-head) pair.
    const int token_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    // const int num_kv_heads = gridDim.y;

    // Retrieve the destination slot for the current token.
    const int64_t slot_idx = slot_mapping_ptr[token_idx];
    // Handle padding: if slot_idx is negative, this token is padding and should be ignored.
    if (slot_idx < 0) {
        return;
    }

    // Calculate the physical block index and the offset within that block.
    const int64_t physical_block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    // Calculate base pointers for source and destination for this specific token.
    const Tdata *k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size;
    const Tdata *v_src_head_ptr = v_ptr + token_idx * v_src_stride + head_idx * head_size;

    // Destination pointer calculation assumes a [num_blocks, num_heads, block_size, head_size] layout.
    // We point to the beginning of the memory region for this token's slot.
    const ptrdiff_t cache_head_stride = block_size * head_size;
    Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
    Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
    Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
    Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;

    //================================================================================
    // 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
    //================================================================================
    for (int i = threadIdx.x; i < head_size; i += NUM_THREADS) {
        k_dst_head_ptr[i] = k_src_head_ptr[i];
        v_dst_head_ptr[i] = v_src_head_ptr[i];
    }
}

} // namespace op::paged_caching::cuda

#endif // __PAGED_CACHING_KERNEL_CUH__
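
As a host-side mirror of the addressing above, here is a sketch under the layout assumption cache[num_blocks][num_kv_heads][block_size][head_size] and contiguous source tensors; names are illustrative and the snippet is not part of the commit.

// Sketch: CPU equivalent of the copy performed by pagedCachingKernel for one
// (token, head) pair, assuming the contiguous layout noted above.
int64_t slot = slot_mapping[token];          /* e.g. slot 35 with block_size 16 */
int64_t physical_block = slot / block_size;  /* -> block 2                      */
int64_t offset         = slot % block_size;  /* -> offset 3                     */
for (size_t d = 0; d < head_size; ++d) {
    size_t dst = ((physical_block * num_kv_heads + head) * block_size + offset) * head_size + d;
    size_t src = (token * num_kv_heads + head) * head_size + d;
    k_cache[dst] = k[src];
    v_cache[dst] = v[src];
}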

src/infiniop/ops/paged_caching/info.h (new file, mode 100644)

#ifndef __PAGED_CACHING_INFO_H__
#define __PAGED_CACHING_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <optional>
#include <vector>

namespace op::paged_caching {

class PagedCachingInfo {
    PagedCachingInfo() = default;

public:
    // --- Data Type ---
    infiniDtype_t dtype;

    // --- Shape Dimensions ---
    size_t num_tokens;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;

    // --- Strides for Memory Layout ---
    ptrdiff_t k_src_stride;
    ptrdiff_t v_src_stride;
    ptrdiff_t k_cache_block_stride;
    ptrdiff_t v_cache_block_stride;

    static utils::Result<PagedCachingInfo> create(
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t slot_mapping_desc) {

        auto dtype = k_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);

        if (v_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (slot_mapping_desc->dtype() != INFINI_DTYPE_I64) {
            printf("slot_mapping must be int64_t.\n");
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (k_desc->ndim() != 3 || v_desc->ndim() != 3
            || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4
            || slot_mapping_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- Extract shape dimensions ---
        auto k_shape = k_desc->shape();
        auto k_cache_shape = k_cache_desc->shape();
        size_t num_tokens = slot_mapping_desc->shape()[0];
        size_t num_kv_heads = k_shape[1];
        size_t head_size = k_shape[2];
        size_t block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout

        // --- Extract strides for memory access ---
        ptrdiff_t k_src_stride = k_desc->stride(0);
        ptrdiff_t v_src_stride = v_desc->stride(0);
        ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
        ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);

        return utils::Result<PagedCachingInfo>(PagedCachingInfo{
            dtype, num_tokens, num_kv_heads, head_size, block_size,
            k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride});
    }
};

} // namespace op::paged_caching

#endif // __PAGED_CACHING_INFO_H__

src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu (new file, mode 100644)

#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_caching_nvidia.cuh"

template <typename Tdata, int NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedCaching(
    Tdata *k_cache, Tdata *v_cache,
    const Tdata *k, const Tdata *v,
    const int64_t *slot_mapping,
    const size_t head_size, const size_t block_size,
    const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
    op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
        k_cache, v_cache, k, v, slot_mapping, head_size, block_size,
        k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}

namespace op::paged_caching::nvidia {

// PIMPL struct definition
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Destructor implementation
Descriptor::~Descriptor() {
    delete _opaque;
}

// Static factory method implementation
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {
    auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
    CHECK_RESULT(info);

    // Create and return the Descriptor instance.
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// launchKernel is a templated helper that encapsulates the CUDA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(
    const PagedCachingInfo &info,
    void *k_cache, void *v_cache,
    infiniDtype_t dtype,
    const void *k, const void *v, const void *slot_mapping,
    size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
    ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
    ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
    cudaStream_t stream) {
    // One thread block per (kv head, token) pair.
    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
    // Block dimension is 1D, using the number of threads specified at compile time.
    dim3 block(NUM_THREADS);
    // This kernel does not require dynamic shared memory.
    size_t shared_mem_size = 0;

    // Launch the device-side CUDA kernel.
    if (dtype == INFINI_DTYPE_F16) {
        pagedCaching<half, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (half *)k_cache, (half *)v_cache, (const half *)k, (const half *)v,
            (const int64_t *)slot_mapping, head_size, block_size,
            k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
    } else if (dtype == INFINI_DTYPE_BF16) {
        pagedCaching<__nv_bfloat16, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (__nv_bfloat16 *)k_cache, (__nv_bfloat16 *)v_cache, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v,
            (const int64_t *)slot_mapping, head_size, block_size,
            k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
    } else if (dtype == INFINI_DTYPE_F32) {
        pagedCaching<float, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
            (float *)k_cache, (float *)v_cache, (const float *)k, (const float *)v,
            (const int64_t *)slot_mapping, head_size, block_size,
            k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

// Execution method implementation
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    const void *k, const void *v,
    void *k_cache, void *v_cache,
    const void *slot_mapping,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;

    // Dispatch logic based on the GPU's maximum threads per block.
    // This allows selecting the largest, most efficient block size the hardware supports.
    if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) {
        // Dispatch based on data type for a 1024-thread block.
        launchKernel<CUDA_BLOCK_SIZE_1024>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
        launchKernel<CUDA_BLOCK_SIZE_512>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
        launchKernel<CUDA_BLOCK_SIZE_4096>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else {
        // If the GPU supports fewer threads per block, return an error.
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::paged_caching::nvidia

src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh (new file, mode 100644)

#ifndef __PAGED_CACHING_NVIDIA_H__
#define __PAGED_CACHING_NVIDIA_H__

#include "../paged_caching.h"

DESCRIPTOR(nvidia)

#endif // __PAGED_CACHING_NVIDIA_H__

src/infiniop/ops/paged_caching/operator.cc (new file, mode 100644)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"

#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif

__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
    infiniopHandle_t handle,
    infiniopPagedCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE)                                                       \
    case CASE:                                                                        \
        return op::paged_caching::NAMESPACE::Descriptor::create(                     \
            handle,                                                                   \
            reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
            k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);

    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
    infiniopPagedCachingDescriptor_t desc,
    size_t *size) {
#define GET(CASE, NAMESPACE)                                                                          \
    case CASE:                                                                                        \
        *size = reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopPagedCaching(
    infiniopPagedCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    const void *k,
    const void *v,
    void *k_cache,
    void *v_cache,
    const void *slot_mapping,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE)                                                             \
    case CASE:                                                                                 \
        return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, k, v, k_cache, v_cache, slot_mapping, stream);

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
    infiniopPagedCachingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                                    \
    case CASE:                                                                      \
        delete reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        DESTROY(INFINI_DEVICE_METAX, metax)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

src/infiniop/ops/paged_caching/paged_caching.h (new file, mode 100644)

#ifndef PAGED_CACHING_H
#define PAGED_CACHING_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
const void *k, const void *v, \
void *k_cache, void *v_cache, \
const void *slot_mapping, \
void *stream) const; \
}; \
}
#endif // PAGED_CACHING_H

test/infiniop/libinfiniop/op_register.py

@@ -977,3 +977,87 @@ def scaled_mm_int8_(lib):
    lib.infiniopDestroyI8GemmDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_attention_(lib):
    lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32
    lib.infiniopCreatePagedAttentionDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_void_p,
        c_float,
    ]

    lib.infiniopGetPagedAttentionWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedAttentionWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopPagedAttention.restype = c_int32
    lib.infiniopPagedAttention.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyPagedAttentionDescriptor.restype = c_int32
    lib.infiniopDestroyPagedAttentionDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_caching_(lib):
    lib.infiniopCreatePagedCachingDescriptor.restype = c_int32
    lib.infiniopCreatePagedCachingDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # k_desc
        infiniopTensorDescriptor_t,  # v_desc
        infiniopTensorDescriptor_t,  # k_cache_desc
        infiniopTensorDescriptor_t,  # v_cache_desc
        infiniopTensorDescriptor_t,  # slot_mapping_desc
    ]

    # infiniopGetPagedCachingWorkspaceSize
    lib.infiniopGetPagedCachingWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedCachingWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # infiniopPagedCaching
    lib.infiniopPagedCaching.restype = c_int32
    lib.infiniopPagedCaching.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # k
        c_void_p,  # v
        c_void_p,  # k_cache
        c_void_p,  # v_cache
        c_void_p,  # slot_mapping
        c_void_p,  # stream
    ]

    # infiniopDestroyPagedCachingDescriptor
    lib.infiniopDestroyPagedCachingDescriptor.restype = c_int32
    lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
test/infiniop/paged_attention.py
0 → 100644
View file @
298feac2
import
torch
import
ctypes
from
ctypes
import
c_uint64
import
math
from
libinfiniop
import
(
LIBINFINIOP
,
TestTensor
,
get_test_devices
,
check_error
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
TestWorkspace
,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def
get_alibi_slopes
(
n
):
# 简化版的ALiBi斜率计算方法
# 参考: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L742
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
n
))
base
=
2
**
(
-
(
2
**
-
(
math
.
log2
(
closest_power_of_2
)
-
3
)))
powers
=
[
base
**
i
for
i
in
range
(
1
,
closest_power_of_2
+
1
)]
if
n
>
closest_power_of_2
:
extra
=
[
base
**
(
i
*
2
)
for
i
in
range
(
1
,
2
*
(
n
-
closest_power_of_2
)
+
1
,
2
)]
powers
+=
extra
return
powers
[:
n
]
def
ref_masked_attention
(
query
,
key
,
value
,
scale
,
attn_mask
=
None
):
# Reference implementation for a single masked attention head.
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
if
attn_mask
is
not
None
:
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
def
ref_single_query_cached_kv_attention
(
query
,
key_cache
,
value_cache
,
block_tables
,
seq_lens
,
scale
,
alibi_slopes
):
# Reference implementation for paged attention, iterating through each sequence.
output
=
torch
.
empty_like
(
query
)
num_query_heads
,
num_kv_heads
=
query
.
shape
[
1
],
value_cache
.
shape
[
1
]
num_queries_per_kv
=
num_query_heads
//
num_kv_heads
head_size
,
block_size
=
value_cache
.
shape
[
3
],
value_cache
.
shape
[
2
]
num_seqs
=
query
.
shape
[
0
]
for
i
in
range
(
num_seqs
):
q
=
query
[
i
].
unsqueeze
(
0
)
seq_len
=
seq_lens
[
i
].
item
()
block_table
=
block_tables
[
i
]
keys_lst
,
values_lst
=
[],
[]
for
j
in
range
(
seq_len
):
block_num
=
block_table
[
j
//
block_size
].
item
()
block_off
=
j
%
block_size
k
=
key_cache
[
block_num
,
:,
block_off
,
:]
v
=
value_cache
[
block_num
,
:,
block_off
,
:]
keys_lst
.
append
(
k
)
values_lst
.
append
(
v
)
keys
=
torch
.
stack
(
keys_lst
,
dim
=
0
)
values
=
torch
.
stack
(
values_lst
,
dim
=
0
)
if
num_queries_per_kv
>
1
:
keys
=
torch
.
repeat_interleave
(
keys
,
num_queries_per_kv
,
dim
=
1
)
values
=
torch
.
repeat_interleave
(
values
,
num_queries_per_kv
,
dim
=
1
)
alibi_bias
=
None
if
alibi_slopes
is
not
None
:
pos
=
torch
.
arange
(
seq_len
,
device
=
query
.
device
).
int
()
alibi_bias
=
(
pos
-
seq_len
+
1
).
float
()
alibi_bias
=
alibi_slopes
.
view
(
-
1
,
1
,
1
)
*
alibi_bias
.
view
(
1
,
1
,
-
1
)
out
=
ref_masked_attention
(
q
,
keys
,
values
,
scale
,
alibi_bias
)
output
[
i
]
=
out
.
view
(
num_query_heads
,
head_size
)
return
output
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_
=
[
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
(
1
,
1
,
1
,
128
,
16
,
1024
,
False
),
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
]
# Data types for testing
_TENSOR_DTYPES
=
[
InfiniDtype
.
BF16
,
InfiniDtype
.
F16
,
InfiniDtype
.
F32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
InfiniDtype
.
F16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-2
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
5e-2
},
InfiniDtype
.
F32
:
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
}
# Global flags for controlling test behavior
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
def
test
(
handle
,
device
,
num_seqs
,
num_heads
,
num_kv_heads
,
head_size
,
block_size
,
max_seq_len
,
use_alibi
,
dtype
=
InfiniDtype
.
F16
,
sync
=
None
,
):
print
(
f
"Testing PagedAttention on
{
InfiniDeviceNames
[
device
]
}
with "
f
"num_seqs=
{
num_seqs
}
, num_heads=
{
num_heads
}
, head_size=
{
head_size
}
, "
f
"block_size=
{
block_size
}
, dtype=
{
InfiniDtypeNames
[
dtype
]
}
, use_alibi=
{
use_alibi
}
"
)
scale
=
1.0
/
(
head_size
**
0.5
)
max_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
num_blocks
=
num_seqs
*
max_blocks_per_seq
# A reasonable number for testing
# Create input tensors
q
=
TestTensor
((
num_seqs
,
num_heads
,
head_size
),
None
,
dtype
,
device
)
out
=
TestTensor
((
num_seqs
,
num_heads
,
head_size
),
None
,
dtype
,
device
)
k_cache
=
TestTensor
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
None
,
dtype
,
device
)
v_cache
=
TestTensor
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
None
,
dtype
,
device
)
seq_lens_direct
=
1023
seq_lens_torch
=
torch
.
randint
(
1
,
seq_lens_direct
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
seq_lens
=
TestTensor
.
from_torch
(
seq_lens_torch
,
InfiniDtype
.
I64
,
device
)
block_tables_py
=
torch
.
arange
(
0
,
num_seqs
*
max_blocks_per_seq
,
dtype
=
torch
.
int64
).
view
(
num_seqs
,
max_blocks_per_seq
)
block_tables
=
TestTensor
.
from_torch
(
block_tables_py
,
InfiniDtype
.
I64
,
device
)
alibi_slopes_desc
=
ctypes
.
c_void_p
(
0
)
alibi_slopes_data
=
ctypes
.
c_void_p
(
0
)
alibi_slopes_torch
=
None
if
use_alibi
:
alibi_slopes
=
TestTensor
((
num_heads
,),
None
,
InfiniDtype
.
F32
,
device
)
alibi_slopes_desc
=
alibi_slopes
.
descriptor
alibi_slopes_data
=
alibi_slopes
.
data
()
alibi_slopes_torch
=
alibi_slopes
.
torch_tensor
()
# Run reference implementation
ans
=
ref_single_query_cached_kv_attention
(
q
.
torch_tensor
(),
k_cache
.
torch_tensor
(),
v_cache
.
torch_tensor
(),
block_tables
.
torch_tensor
(),
seq_lens
.
torch_tensor
(),
scale
,
alibi_slopes_torch
,
)
if
sync
:
sync
()
scale
=
1.0
/
(
head_size
**
0.5
)
# Create operator descriptor
descriptor
=
infiniopOperatorDescriptor_t
()
check_error
(
LIBINFINIOP
.
infiniopCreatePagedAttentionDescriptor
(
handle
,
ctypes
.
byref
(
descriptor
),
out
.
descriptor
,
q
.
descriptor
,
k_cache
.
descriptor
,
v_cache
.
descriptor
,
block_tables
.
descriptor
,
seq_lens
.
descriptor
,
alibi_slopes_desc
,
scale
,
)
)
# Get workspace size and allocate memory
workspace_size
=
c_uint64
(
0
)
check_error
(
LIBINFINIOP
.
infiniopGetPagedAttentionWorkspaceSize
(
descriptor
,
ctypes
.
byref
(
workspace_size
)
)
)
workspace
=
TestWorkspace
(
workspace_size
.
value
,
q
.
device
)
# Invalidate descriptors to ensure kernel does not rely on them
q
.
destroy_desc
()
out
.
destroy_desc
()
k_cache
.
destroy_desc
()
v_cache
.
destroy_desc
()
block_tables
.
destroy_desc
()
seq_lens
.
destroy_desc
()
if
use_alibi
:
alibi_slopes
.
destroy_desc
()
# Define the library call as a lambda for profiling
def
lib_paged_attention
():
check_error
(
LIBINFINIOP
.
infiniopPagedAttention
(
descriptor
,
workspace
.
data
(),
workspace_size
.
value
,
out
.
data
(),
q
.
data
(),
k_cache
.
data
(),
v_cache
.
data
(),
block_tables
.
data
(),
seq_lens
.
data
(),
alibi_slopes_data
,
None
,
)
)
# Execute the custom operator
lib_paged_attention
()
if
sync
:
sync
()
# Verify correctness
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
out
.
actual_tensor
(),
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
out
.
actual_tensor
(),
ans
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
ref_single_query_cached_kv_attention
(
q
.
torch_tensor
(),
k_cache
.
torch_tensor
(),
v_cache
.
torch_tensor
(),
block_tables
.
torch_tensor
(),
seq_lens
.
torch_tensor
(),
scale
,
alibi_slopes_torch
),
device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lib_paged_attention
,
device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
    # Clean up resources
    check_error(LIBINFINIOP.infiniopDestroyPagedAttentionDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
    # Configure testing options from command line arguments
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")
test/infiniop/paged_caching.py
0 → 100644
View file @
298feac2
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
    """
    Reference implementation for the paged_caching operator.

    Args:
        key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
        value (torch.Tensor): Values, shape [ntok, nkvh, dh]
        key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
        value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
        slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
    """
    ntok = key.shape[0]
    block_size = key_cache_pool.shape[2]
    # This reference implementation operates on a cloned cache to avoid modifying the
    # original input tensor, mimicking the behavior where the custom operator writes
    # to its output tensor.
    k_cache_ref = key_cache_pool.clone()
    v_cache_ref = value_cache_pool.clone()
    for i in range(ntok):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        key_token = key[i]
        value_token = value[i]
        k_cache_ref[block_idx, :, block_offset, :] = key_token
        v_cache_ref[block_idx, :, block_offset, :] = value_token
    return k_cache_ref, v_cache_ref
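# Worked example of the slot arithmetic above (illustrative numbers only): with
# block_size = 16, slot 37 maps to block_idx = 37 // 16 = 2 and
# block_offset = 37 % 16 = 5, so that token's key and value are written to
# k_cache_ref[2, :, 5, :] and v_cache_ref[2, :, 5, :].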
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
    # (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
    (1, 128, 8, 128, 16),
    (5, 512, 40, 128, 16),
    (16, 1024, 8, 64, 32),
    (10, 1024, 40, 64, 32),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}

# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
    handle,
    device,
    num_seqs,  # nreq
    max_seq_len,
    num_kv_heads,  # nkvh
    head_size,  # dh
    block_size,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
        f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, "
        f"head_size={head_size}, block_size={block_size}, dtype={InfiniDtypeNames[dtype]}"
    )
    num_blocks = 4096  # A reasonably large cache pool for testing

    # Create metadata: variable context lengths for each sequence in the batch
    context_lens_torch = torch.randint(1, max_seq_len + 1, (num_seqs,), dtype=torch.int64)
    ntok = torch.sum(context_lens_torch).item()

    # If ntok is 0 (all sequences have length 0), skip the test
    if ntok == 0:
        print("Skipping test case with ntok=0")
        return
    # Simulate the scheduler's behavior to create the slot_mapping
    slot_mapping_list = []
    current_slot = 0
    for length in context_lens_torch:
        # Find a contiguous chunk of 'length' slots
        start_slot = current_slot
        slot_mapping_list.extend(range(start_slot, start_slot + length.item()))
        current_slot += length.item()

    # Ensure we don't exceed the total number of slots in the cache
    assert (
        current_slot <= num_blocks * block_size
    ), "Not enough blocks in the cache pool for this test case"

    slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)
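    # With the contiguous layout above, slot_mapping is simply 0..ntok-1 (equivalent
    # to torch.arange(ntok, dtype=torch.int64)). A real scheduler may hand out
    # non-contiguous slots across blocks; this test deliberately uses the simple
    # contiguous case.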
    # Create input tensors based on the calculated total tokens (ntok)
    k = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
    v = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
    slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I64, device)

    # The cache pools are the "output" tensors for this operator
    k_cache_pool = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    v_cache_pool = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    # Run reference implementation
    k_cache_ref, v_cache_ref = ref_paged_caching(
        k.torch_tensor(),
        v.torch_tensor(),
        k_cache_pool.torch_tensor(),
        v_cache_pool.torch_tensor(),
        slot_mapping.torch_tensor(),
    )
    if sync:
        sync()
    # Create operator descriptor
    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
            handle,
            ctypes.byref(descriptor),
            k.descriptor,
            v.descriptor,
            k_cache_pool.descriptor,
            v_cache_pool.descriptor,
            slot_mapping.descriptor,
        )
    )
    # Get workspace size (likely 0 for this operator, but good practice to include)
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)
    # Invalidate descriptors to ensure kernel does not rely on them
    k.destroy_desc()
    v.destroy_desc()
    k_cache_pool.destroy_desc()
    v_cache_pool.destroy_desc()
    slot_mapping.destroy_desc()
    # Define the library call as a closure for profiling
    def lib_paged_caching():
        check_error(
            LIBINFINIOP.infiniopPagedCaching(
                descriptor,
                workspace.data(),
                workspace_size.value,
                k.data(),
                v.data(),
                k_cache_pool.data(),
                v_cache_pool.data(),
                slot_mapping.data(),
                None,
            )
        )

    # Execute the custom operator
    lib_paged_caching()
    if sync:
        sync()
    # Verify correctness
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        print("Verifying K cache...")
        debug(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
        print("Verifying V cache...")
        debug(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: ref_paged_caching(
            k.torch_tensor(), v.torch_tensor(), k_cache_pool.torch_tensor(),
            v_cache_pool.torch_tensor(), slot_mapping.torch_tensor()),
            device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lib_paged_caching, device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    # Clean up resources
    check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
    # Configure testing options from command line arguments
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")
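Both test scripts above are standalone entry points. Judging from the attributes read in their __main__ blocks (args.debug, args.profile, args.num_prerun, args.num_iterations), they are presumably invoked directly with flags provided by the shared get_args() helper, along the lines of:

    python test/infiniop/paged_attention.py --debug
    python test/infiniop/paged_caching.py --profile

The exact flag names and any device-selection options come from the libinfiniop test harness and are assumed here rather than confirmed by this diff.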