jerrrrry / infinicore

Commit 1ba0bcfa, authored Dec 30, 2025 by zhushuang

issue/848 - feat: add paged attention prefill for nvidia gpu with test pass

Parent: 298feac2
Showing 11 changed files with 1235 additions and 2 deletions (+1235 −2).
include/infiniop.h                                                                   +3    -2
include/infiniop/ops/paged_attention_prefill.h                                       +83   -0
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh                             +134  -0
src/infiniop/ops/paged_attention_prefill/info.h                                      +107  -0
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu    +136  -0
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh   +8    -0
src/infiniop/ops/paged_attention_prefill/operator.cc                                 +95   -0
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h                   +56   -0
test/infiniop/libinfiniop/op_register.py                                             +48   -0
test/infiniop/paged_attention_prefill.py                                             +315  -0
test/infiniop/paged_caching_prefill.py                                               +250  -0
include/infiniop.h

@@ -15,6 +15,9 @@
 #include "infiniop/ops/lp_norm.h"
 #include "infiniop/ops/mul.h"
 #include "infiniop/ops/ones.h"
+#include "infiniop/ops/paged_attention.h"
+#include "infiniop/ops/paged_attention_prefill.h"
+#include "infiniop/ops/paged_caching.h"
 #include "infiniop/ops/random_sample.h"
 #include "infiniop/ops/rearrange.h"
 #include "infiniop/ops/relu.h"
@@ -31,7 +34,5 @@
 #include "infiniop/ops/topksoftmax.h"
 #include "infiniop/ops/zeros.h"
 #include "infiniop/tensor_descriptor.h"
-#include "infiniop/ops/paged_attention.h"
-#include "infiniop/ops/paged_caching.h"

 #endif // __INFINIOP_API_H__
include/infiniop/ops/paged_attention_prefill.h (new file, mode 100644)

#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Attention Prefill descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention Prefill operation.
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor (packed/flattened).
 * @param k_cache_desc Descriptor for the global physical key cache.
 * @param v_cache_desc Descriptor for the global physical value cache.
 * @param block_tables_desc Descriptor for the block tables mapping logical to physical blocks.
 * @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
 * @param seq_lens_desc Descriptor for the current prefill sequence lengths.
 * @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention Prefill operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    size_t *size);

/**
 * @brief Executes the Paged Attention Prefill operation.
 * @param desc The Paged Attention Prefill descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data (packed).
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param cache_lens Pointer to the total sequence lengths data.
 * @param seq_lens Pointer to the current prefill sequence lengths data.
 * @param offset Pointer to the sequence start offsets data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA/device stream for the operation.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention Prefill descriptor.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
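Reviewer note: taken together, the four entry points follow InfiniOP's usual descriptor lifecycle (create, query workspace, run, destroy). The sketch below shows one plausible host-side call sequence; the handle, tensor descriptors, device buffers, and the CHECK error-handling macro are assumptions for illustration and are not part of this commit.

// Hypothetical usage sketch (handle, descriptors, device buffers, and CHECK
// are assumed to be defined elsewhere; they are not part of this commit).
infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;
CHECK(infiniopCreatePagedAttentionPrefillDescriptor(
    handle, &desc, out_desc, q_desc, k_cache_desc, v_cache_desc,
    block_tables_desc, cache_lens_desc, seq_lens_desc, offset_desc,
    /*alibi_slopes_desc=*/nullptr, /*scale=*/0.0883883f /* = 128^-0.5 */));

size_t workspace_size = 0;
CHECK(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
void *workspace = nullptr; // allocate workspace_size bytes on the device if nonzero

CHECK(infiniopPagedAttentionPrefill(
    desc, workspace, workspace_size, out, q, k_cache, v_cache,
    block_tables, cache_lens, seq_lens, offset,
    /*alibi_slopes=*/nullptr, stream));

CHECK(infiniopDestroyPagedAttentionPrefillDescriptor(desc));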
src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh (new file, mode 100644)

#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__

namespace op::paged_attention_prefill::cuda {

// Helper: binary search to determine which sequence the current global_token_idx belongs to.
__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
    int low = 0, high = num_seqs - 1;
    while (low <= high) {
        int mid = (low + high) >> 1;
        if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
            return mid;
        } else if (token_idx < offset[mid]) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return 0;
}

template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int64_t *block_tables_,
    const int64_t *cache_lens_,
    const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_heads,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const size_t head_size,
    const int64_t *offset_,
    const size_t num_seqs) {
    // --- 2D grid coordinates ---
    const int global_token_idx = blockIdx.x; // flattened global token index
    const int head_idx = blockIdx.y;         // head index
    const int dim_idx = threadIdx.x;         // dimension within the head
    if (dim_idx >= head_size) {
        return;
    }

    // --- Binary-search the offset array for the owning seq_idx ---
    int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);

    // --- Number of tokens this sequence contributes to the current prefill
    const int64_t cur_new_len = seq_lens_[seq_idx];
    // --- Position of this token relative to the start of its sequence
    int q_token_idx = global_token_idx - offset_[seq_idx];

    const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
    Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;

    // --- KV cache bookkeeping
    const int64_t total_seq_len = cache_lens_[seq_idx];
    const int64_t history_len = total_seq_len - cur_new_len;
    const int64_t causal_limit = history_len + q_token_idx;

    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // Pass 1: compute scores and find the maximum
    Tcompute max_score = -FLT_MAX;
    for (int t = 0; t <= causal_limit; ++t) {
        const int64_t b_idx = t / block_size;
        const int64_t t_off = t % block_size;
        const int64_t physical_block_id = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
        Tcompute score = 0.0f;
        for (int d = 0; d < head_size; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += alibi_slope * static_cast<float>(t - causal_limit);
        }
        if (score > max_score) {
            max_score = score;
        }
    }

    // Pass 2: compute the sum of exponentials
    Tcompute sum_exp = 0.0f;
    for (int t = 0; t <= causal_limit; ++t) {
        const int64_t b_idx = t / block_size;
        const int64_t t_off = t % block_size;
        const int64_t physical_block_id = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
        Tcompute score = 0.0f;
        for (int d = 0; d < head_size; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += alibi_slope * static_cast<float>(t - causal_limit);
        }
        sum_exp += expf(static_cast<float>(score - max_score));
    }

    // Pass 3: weighted sum over V to produce the output
    Tcompute acc = 0.0f;
    Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
    for (int t = 0; t <= causal_limit; ++t) {
        const int64_t b_idx = t / block_size;
        const int64_t t_off = t % block_size;
        const int64_t physical_block_id = block_table[b_idx];
        const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
        Tcompute score = 0.0f;
        for (int d = 0; d < head_size; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += alibi_slope * static_cast<float>(t - causal_limit);
        }
        Tcompute prob = expf(static_cast<float>(score - max_score)) * inv_sum;
        const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
        acc += prob * static_cast<Tcompute>(v_vec[dim_idx]);
    }
    out_ptr[dim_idx] = static_cast<Tdata>(acc);
}

} // namespace op::paged_attention_prefill::cuda

#endif
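Reviewer note: the three passes implement the numerically stable softmax attention below, where s_t is the (optionally ALiBi-shifted) scaled dot product, m is the running maximum from pass 1, alpha_h is the per-head ALiBi slope, and the 1e-6 term is the kernel's inv_sum guard against an empty or underflowed denominator:

$$
s_t = \mathrm{scale} \cdot q^{\top} k_t + \alpha_h \,(t - \text{causal\_limit}), \qquad
m = \max_{0 \le t \le \text{causal\_limit}} s_t,
$$
$$
o = \frac{\sum_{t=0}^{\text{causal\_limit}} e^{s_t - m}\, v_t}{\sum_{t=0}^{\text{causal\_limit}} e^{s_t - m} + 10^{-6}}.
$$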
src/infiniop/ops/paged_attention_prefill/info.h (new file, mode 100644)

#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"

#include <iostream>
#include <optional>
#include <vector>

namespace op::paged_attention_prefill {

class PagedAttentionPrefillInfo {
    PagedAttentionPrefillInfo() = default;

public:
    infiniDtype_t dtype;
    float scale;
    size_t num_seqs;
    size_t num_heads;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    size_t max_num_blocks_per_seq;
    size_t total_q_tokens;
    ptrdiff_t q_stride;
    ptrdiff_t kv_block_stride;
    ptrdiff_t kv_head_stride;
    ptrdiff_t o_stride;

    static utils::Result<PagedAttentionPrefillInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t cache_lens_desc,
        infiniopTensorDescriptor_t seq_lens_desc,
        infiniopTensorDescriptor_t offset_desc,
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {

        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);

        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
            std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl;
            return INFINI_STATUS_BAD_PARAM;
        }

        // Q shape: [total_tokens, heads, dim] (3D)
        auto q_shape = q_desc->shape();
        if (q_shape.size() < 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t total_q_tokens = q_shape[0];
        size_t num_heads = q_shape[q_shape.size() - 2];
        size_t head_size = q_shape[q_shape.size() - 1];
        if (head_size != 128) {
            std::cerr << "[Error] PagedAttentionPrefill: only head_size = 128 is supported, got " << head_size << std::endl;
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Get num_seqs from seq_lens
        size_t num_seqs = seq_lens_desc->shape()[0];
        auto k_cache_shape = k_cache_desc->shape();
        size_t num_kv_heads = k_cache_shape[1];
        size_t block_size = v_cache_desc->shape()[2];
        size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];

        // Extract strides; the packed Q of multiple requests must stay contiguous.
        ptrdiff_t q_stride = q_desc->stride(0);
        ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
        ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
        ptrdiff_t o_stride = out_desc->stride(0);

        return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
            dtype,
            scale,
            num_seqs,
            num_heads,
            num_kv_heads,
            head_size,
            block_size,
            max_num_blocks_per_seq,
            total_q_tokens,
            q_stride,
            kv_block_stride,
            kv_head_stride,
            o_stride});
    }
};

} // namespace op::paged_attention_prefill

#endif
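Reviewer note: the strides are read from the descriptors rather than recomputed, but for the contiguous cache layout the tests use, [num_blocks, num_kv_heads, block_size, head_size], they reduce to simple products. A minimal self-contained sketch, with made-up sizes for illustration:

#include <cstddef>
#include <cstdio>

int main() {
    // Contiguous cache layout [num_blocks, num_kv_heads, block_size, head_size]:
    // strides are element counts, innermost dimension varying fastest.
    const std::size_t num_kv_heads = 8, block_size = 16, head_size = 128;
    const std::ptrdiff_t kv_head_stride = block_size * head_size;         // 2048 elements per head slab
    const std::ptrdiff_t kv_block_stride = num_kv_heads * kv_head_stride; // 16384 elements per block
    std::printf("kv_block_stride=%td, kv_head_stride=%td\n", kv_block_stride, kv_head_stride);
    return 0;
}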
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu (new file, mode 100644)

#include <cuda_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>

#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_prefill_nvidia.cuh"

// ==============================================================================
// Host wrapper to launch the global kernel
// ==============================================================================
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const int64_t *block_tables,
    const int64_t *cache_lens,
    const int64_t *seq_lens,
    const int64_t *offset,
    const float *alibi_slopes,
    const size_t num_heads,
    const size_t num_seqs,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const size_t total_q_tokens,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride,
    const size_t head_size,
    cudaStream_t stream) {

    if (total_q_tokens == 0 || num_heads == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    // Use a 2D grid: the X axis spans all tokens, the Y axis spans all heads.
    dim3 grid(total_q_tokens, num_heads);
    dim3 block(head_size);

    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
        <<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens,
            alibi_slopes, num_heads, num_kv_heads, scale,
            max_num_blocks_per_seq, block_size, kv_block_stride, kv_head_stride,
            head_size, offset, num_seqs);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA Kernel Launch Failed: " << cudaGetErrorString(err) << std::endl;
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}

namespace op::paged_attention_prefill::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {

    auto info = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,
        cache_lens_desc, seq_lens_desc, offset_desc, alibi_slopes_desc, scale);
    CHECK_RESULT(info);

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream_) const {

    cudaStream_t stream = (cudaStream_t)stream_;
    if (_info.head_size > 1024) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

#define LAUNCH_KERNEL(Tdata, Tcompute) \
    launchPagedAttentionPrefill<Tdata, Tcompute>( \
        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
        (const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens, \
        (const int64_t *)offset, \
        (const float *)alibi_slopes, \
        _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
        _info.scale, _info.max_num_blocks_per_seq, \
        _info.block_size, _info.total_q_tokens, \
        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
        _info.head_size, \
        stream)

    if (_info.dtype == INFINI_DTYPE_F16) {
        return LAUNCH_KERNEL(half, float);
    } else if (_info.dtype == INFINI_DTYPE_BF16) {
        return LAUNCH_KERNEL(__nv_bfloat16, float);
    } else if (_info.dtype == INFINI_DTYPE_F32) {
        return LAUNCH_KERNEL(float, float);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::paged_attention_prefill::nvidia
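Reviewer note: one subtlety the launcher depends on is the offset convention consumed by the kernel's find_seq_id helper: offset holds num_seqs + 1 cumulative token counts, so offset[i] .. offset[i+1] brackets sequence i in the packed Q tensor. A host-side copy of the device helper makes the behavior easy to check; the offsets below are made up for illustration:

#include <cstdint>
#include <cstdio>

// Host-side mirror of op::paged_attention_prefill::cuda::find_seq_id.
static int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
    int low = 0, high = num_seqs - 1;
    while (low <= high) {
        int mid = (low + high) >> 1;
        if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) return mid;
        if (token_idx < offset[mid]) high = mid - 1;
        else low = mid + 1;
    }
    return 0;
}

int main() {
    const int64_t offset[] = {0, 3, 7, 9}; // three sequences of lengths 3, 4, 2
    std::printf("token 2 -> seq %d\n", find_seq_id(2, offset, 3)); // prints 0
    std::printf("token 5 -> seq %d\n", find_seq_id(5, offset, 3)); // prints 1
    std::printf("token 8 -> seq %d\n", find_seq_id(8, offset, 3)); // prints 2
    return 0;
}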
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh (new file, mode 100644)

#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__

#include "../paged_attention_prefill.h"

DESCRIPTOR(nvidia)

#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__
src/infiniop/ops/paged_attention_prefill/operator.cc (new file, mode 100644)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"

#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif

__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale) {

    infiniopTensorDescriptor_t alibi_opt = (alibi_slopes_desc == nullptr) ? nullptr : alibi_slopes_desc;

#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, \
            seq_lens_desc, offset_desc, alibi_opt, scale);

    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
            cache_lens, seq_lens, offset, alibi_slopes, stream);

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc) {

#define DESTROY(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h (new file, mode 100644)
#ifndef PAGED_ATTENTION_PREFILL_H
#define PAGED_ATTENTION_PREFILL_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionPrefillInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionPrefillInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t cache_lens_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t offset_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, const void *cache_lens, const void *seq_lens, \
const void *offset, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_PREFILL_H
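Reviewer note: as a reading aid, DESCRIPTOR(nvidia), as invoked in the NVIDIA header earlier, expands to the class skeleton below (abridged here; the full member list is in the macro above):

namespace op::paged_attention_prefill::nvidia {
class Descriptor final : public InfiniopDescriptor {
    struct Opaque;  // backend-private state, defined in the .cu file
    Opaque *_opaque;
    PagedAttentionPrefillInfo _info;
    size_t _workspace_size;
    // ... private constructor, then the public create()/calculate() interface ...
};
} // namespace op::paged_attention_prefill::nvidia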
test/infiniop/libinfiniop/op_register.py

@@ -939,6 +939,7 @@ def tanh_(lib):
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def scaled_mm_int8_(lib):
    lib.infiniopCreateI8GemmDescriptor.restype = c_int32

@@ -1061,3 +1062,50 @@ def paged_caching_(lib):
    lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_attention_prefill_(lib):
    lib.infiniopCreatePagedAttentionPrefillDescriptor.restype = c_int32
    lib.infiniopCreatePagedAttentionPrefillDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]
    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]
    lib.infiniopPagedAttentionPrefill.restype = c_int32
    lib.infiniopPagedAttentionPrefill.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
    lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
test/infiniop/paged_attention_prefill.py (new file, mode 100644)

import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    TestWorkspace,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
    # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
    (16, 128, 128, 128, 8, 16, 4),
]

_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]

_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10


# ==============================================================================
#  Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        if request_id not in self.request_to_len:
            self.request_to_len[request_id] = 0
            self.request_to_blocks[request_id] = []
        start_pos = self.request_to_len[request_id]
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
        for _ in range(added_blocks):
            self.request_to_blocks[request_id].append(self.free_blocks.pop(0))
        self.request_to_len[request_id] = new_total_len
        return self.request_to_blocks[request_id], new_total_len


def ref_paged_attention_multi_turn(
    query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
):
    block_size = k_cache.shape[2]
    outputs = torch.zeros_like(query_new)
    for i in range(len(offset) - 1):
        total_len = seq_lens[i].item()
        num_new = new_lens[i].item()
        history_len = total_len - num_new
        table = block_tables[i]

        keys_all, values_all = [], []
        for j in range(total_len):
            b_id = table[j // block_size].item()
            off = j % block_size
            keys_all.append(k_cache[b_id, :, off, :])
            values_all.append(v_cache[b_id, :, off, :])
        K = torch.stack(keys_all, dim=0)
        V = torch.stack(values_all, dim=0)

        Q = query_new[offset[i] : offset[i] + num_new, :, :]
        scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale

        mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
        for q_idx in range(num_new):
            mask[q_idx, : history_len + q_idx + 1] = 0.0
        scores = scores + mask.unsqueeze(0)

        attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, V)
        outputs[offset[i] : offset[i] + num_new, :, :] = out
    return outputs
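Reviewer note: the reference builds the same causal structure the kernel enforces via causal_limit. A query at relative position q in the current chunk may attend to all history tokens plus the new tokens up to and including itself; in equation form:

$$
\text{mask}[q, t] =
\begin{cases}
0 & t \le \text{history\_len} + q, \\
-\infty & \text{otherwise},
\end{cases}
\qquad 0 \le q < \text{num\_new}, \quad 0 \le t < \text{total\_len}.
$$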
# ==============================================================================
#  Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    num_heads,
    num_kv_heads,
    head_size,
    block_size,
    max_step_len,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
        f"seqs: {num_seqs}, heads: {num_heads}, head_size: {head_size}, "
        f"block: {block_size}, max_step_len: {max_step_len}, num_rounds: {num_rounds}, dtype: {InfiniDtypeNames[dtype]}"
    )

    # 1. Initialize persistent resources
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)
    scale = head_size**-0.5

    k_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device)
    v_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device)

    # Multi-turn testing loop
    for r in range(num_rounds):
        # Prepare dynamic inputs for this round
        seq_lens_cpu = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
        q_total_tokens = seq_lens_cpu.sum().item()
        q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)

        cache_lens_list = []
        all_block_tables = []
        offset_list = []
        cur_offset = 0

        for i in range(num_seqs):
            offset_list.append(cur_offset)
            cur_new_len = seq_lens_cpu[i].item()
            table, cache_len = manager.allocate_slots(i, cur_new_len)
            cache_lens_list.append(cache_len)
            all_block_tables.append(table)

            # Simulated KV insertion
            k_new = torch.randn(cur_new_len, num_kv_heads, head_size)
            v_new = torch.randn(cur_new_len, num_kv_heads, head_size)
            q_val = torch.randn(cur_new_len, num_heads, head_size)
            q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val
            cur_offset = cur_offset + cur_new_len

            history_len = cache_len - cur_new_len
            for t in range(cur_new_len):
                logical_pos = history_len + t
                b_id = table[logical_pos // block_size]
                off = logical_pos % block_size
                k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
                v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]

        offset_list.append(cur_offset)
        k_cache.actual_tensor().copy_(k_cache._torch_tensor)
        v_cache.actual_tensor().copy_(v_cache._torch_tensor)

        # 2. Wrap tensors for Infiniop
        q_new = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out.actual_tensor().zero_()

        cache_lens = TestTensor.from_torch(
            torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device
        )
        seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device)
        offset = TestTensor.from_torch(
            torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device
        )

        max_blocks = max(len(t) for t in all_block_tables)
        padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
        block_tables = TestTensor.from_torch(
            torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
        )

        # 3. Reference Calculation
        def torch_paged_attention_multi_turn():
            return ref_paged_attention_multi_turn(
                q_new.torch_tensor(),
                k_cache.torch_tensor(),
                v_cache.torch_tensor(),
                block_tables.torch_tensor(),
                cache_lens.torch_tensor(),
                seq_lens.torch_tensor(),
                offset.torch_tensor(),
                scale,
            )

        ans = torch_paged_attention_multi_turn()

        # 4. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
                handle,
                ctypes.byref(descriptor),
                out.descriptor,
                q_new.descriptor,
                k_cache.descriptor,
                v_cache.descriptor,
                block_tables.descriptor,
                cache_lens.descriptor,
                seq_lens.descriptor,
                offset.descriptor,
                None,  # alibi_slopes_desc
                scale,
            )
        )

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        def lib_attn():
            check_error(
                LIBINFINIOP.infiniopPagedAttentionPrefill(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    out.data(),
                    q_new.data(),
                    k_cache.data(),
                    v_cache.data(),
                    block_tables.data(),
                    cache_lens.data(),
                    seq_lens.data(),
                    offset.data(),
                    None,
                    None,
                )
            )

        lib_attn()
        if sync:
            sync()

        # 5. Validation
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
        assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)

        # Profiling
        if PROFILE:
            profile_operation(
                f"Torch_R{r}",
                lambda: torch_paged_attention_multi_turn(),
                device,
                NUM_PRERUN,
                NUM_ITERATIONS,
            )
            profile_operation(
                f"  Lib_R{r}", lambda: lib_attn(), device, NUM_PRERUN, NUM_ITERATIONS
            )

        check_error(LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
test/infiniop/paged_caching_prefill.py (new file, mode 100644)

import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
    # num_seqs, max_step_len, num_kv_heads, head_size, block_size, num_rounds
    (1, 16, 1, 128, 8, 5),
    (2, 64, 8, 128, 16, 2),
    (8, 128, 32, 128, 16, 3),
    (5, 512, 40, 128, 16, 3),
    (16, 64, 8, 128, 32, 1),
    (10, 256, 40, 128, 32, 3),
]

_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-8, "rtol": 1e-8},
    InfiniDtype.F16: {"atol": 1e-8, "rtol": 1e-8},
    InfiniDtype.BF16: {"atol": 1e-8, "rtol": 1e-8},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10


# ==============================================================================
#  Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        if request_id not in self.request_to_len:
            self.request_to_len[request_id] = 0
            self.request_to_blocks[request_id] = []
        start_pos = self.request_to_len[request_id]
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
        for _ in range(added_blocks):
            self.request_to_blocks[request_id].append(self.free_blocks.pop(0))

        slots = []
        for i in range(start_pos, new_total_len):
            block_idx_in_seq = i // self.block_size
            block_offset = i % self.block_size
            physical_block_id = self.request_to_blocks[request_id][block_idx_in_seq]
            slots.append(physical_block_id * self.block_size + block_offset)

        self.request_to_len[request_id] = new_total_len
        return torch.tensor(slots, dtype=torch.int32)


def ref_paged_caching(k_new, v_new, k_pool, v_pool, slots, block_size):
    """Reference implementation for incremental caching."""
    for i in range(k_new.shape[0]):
        slot = slots[i].item()
        b_id = slot // block_size
        off = slot % block_size
        k_pool[b_id, :, off, :] = k_new[i]
        v_pool[b_id, :, off, :] = v_new[i]
    return k_pool, v_pool
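Reviewer note: the flat slot ids returned by allocate_slots encode (physical block, in-block offset) pairs, and ref_paged_caching inverts the encoding with a divide and a modulo. A small self-contained illustration of the arithmetic, with made-up block ids:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t block_size = 8;
    const int64_t block_table[] = {42, 7, 13};  // physical blocks of one sequence (made-up ids)
    const int64_t logical_pos = 10;             // the sequence's 11th token
    const int64_t physical_block = block_table[logical_pos / block_size]; // 10 / 8 = 1 -> block 7
    const int64_t in_block_off = logical_pos % block_size;                // 10 % 8 = 2
    const int64_t slot = physical_block * block_size + in_block_off;      // 7 * 8 + 2 = 58
    // Decoding, as ref_paged_caching does:
    std::printf("slot %lld -> block %lld, offset %lld\n",
                (long long)slot, (long long)(slot / block_size), (long long)(slot % block_size));
    return 0;
}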
# ==============================================================================
#  Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    max_step_len,
    num_kv_heads,
    head_size,
    block_size,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
        f"seqs: {num_seqs}, max_step_len: {max_step_len}, num_kv_heads: {num_kv_heads}, head_size: {head_size}, "
        f"block_size: {block_size}, rounds: {num_rounds}, dtype: {InfiniDtypeNames[dtype]}"
    )

    # 1. Initialize Global Cache Pool
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)

    k_cache_pool = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device)
    v_cache_pool = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device)

    # Reference pools (CPU/Torch)
    k_pool_ref = k_cache_pool.torch_tensor().clone()
    v_pool_ref = v_cache_pool.torch_tensor().clone()

    for r in range(num_rounds):
        # Prepare incremental data for this round
        round_ntok_list = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)

        all_slots, all_k, all_v = [], [], []
        for i in range(num_seqs):
            n_new = round_ntok_list[i].item()
            all_slots.append(manager.allocate_slots(i, n_new))
            all_k.append(torch.randn(n_new, num_kv_heads, head_size))
            all_v.append(torch.randn(n_new, num_kv_heads, head_size))

        k_in_torch = torch.cat(all_k, dim=0)
        v_in_torch = torch.cat(all_v, dim=0)
        slots_torch = torch.cat(all_slots, dim=0)

        k_in = TestTensor.from_torch(k_in_torch, dtype, device)
        v_in = TestTensor.from_torch(v_in_torch, dtype, device)
        slot_mapping = TestTensor.from_torch(slots_torch, InfiniDtype.I64, device)

        # 2. Reference Calculation
        def torch_caching():
            nonlocal k_pool_ref, v_pool_ref
            return ref_paged_caching(
                k_in.torch_tensor(),
                v_in.torch_tensor(),
                k_pool_ref,
                v_pool_ref,
                slots_torch,
                block_size,
            )

        torch_caching()

        # 3. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
                handle,
                ctypes.byref(descriptor),
                k_in.descriptor,
                v_in.descriptor,
                k_cache_pool.descriptor,
                v_cache_pool.descriptor,
                slot_mapping.descriptor,
            )
        )

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        def lib_caching():
            check_error(
                LIBINFINIOP.infiniopPagedCaching(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    k_in.data(),
                    v_in.data(),
                    k_cache_pool.data(),
                    v_cache_pool.data(),
                    slot_mapping.data(),
                    None,
                )
            )

        lib_caching()
        if sync:
            sync()

        # 4. Validation
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            # Check a small slice of the updated cache
            debug(k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol)
        assert torch.allclose(k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol)
        assert torch.allclose(v_cache_pool.actual_tensor(), v_pool_ref, atol=atol, rtol=rtol)

        # 5. Profiling
        if PROFILE:
            profile_operation(
                f"Torch_R{r}",
                lambda: torch_caching(),
                device,
                NUM_PRERUN,
                NUM_ITERATIONS,
            )
            profile_operation(
                f"  Lib_R{r}", lambda: lib_caching(), device, NUM_PRERUN, NUM_ITERATIONS
            )

        check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")