Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a1937618
Commit
a1937618
authored
Apr 23, 2026
by
yaoht
Browse files
add pagedCachingBf16Head128Block256
parent
1b5a38de
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
4 deletions
+108
-4
src/infiniop/ops/paged_caching/cuda/kernel.cuh
src/infiniop/ops/paged_caching/cuda/kernel.cuh
+61
-0
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
...infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
+47
-4
No files found.
src/infiniop/ops/paged_caching/cuda/kernel.cuh
View file @
a1937618
...
...
@@ -85,6 +85,67 @@ __device__ void pagedCachingKernel(
}
}
#if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
#if defined(__CUDACC__)
#include <vector_types.h>
/**
 * Copies 8 consecutive bf16 elements from `src` to `dst`.
 *
 * Uses a single 16-byte (uint4) load/store when both addresses are 16-byte
 * aligned; otherwise falls back to element-wise copies so that unaligned
 * base pointers or strides cannot trigger a misaligned-address fault.
 */
__device__ __forceinline__ void pagedCachingCopy8Bf16(
    __nv_bfloat16 *dst,
    const __nv_bfloat16 *src) {
    constexpr int VEC_BF16 = 8;
    constexpr size_t VEC_ALIGN_MASK = sizeof(uint4) - 1;

    const size_t addr_bits = reinterpret_cast<size_t>(dst) | reinterpret_cast<size_t>(src);
    if ((addr_bits & VEC_ALIGN_MASK) == 0) {
        const uint4 t = *reinterpret_cast<const uint4 *>(src);
        *reinterpret_cast<uint4 *>(dst) = t;
    } else {
#pragma unroll
        for (int i = 0; i < VEC_BF16; ++i) {
            dst[i] = src[i];
        }
    }
}

/**
 * Writes one (token, head) row of K and of V, 128 bf16 values each, into the
 * paged K/V caches.
 *
 * Indexing read by this function:
 *   - blockIdx.y selects the token, blockIdx.x selects the head.
 *   - Threads 0..15 (threadIdx.x) each move 8 consecutive bf16 elements;
 *     threads with threadIdx.x >= 16 return without doing any work.
 *   - The cache page holds 256 slots: slot_idx >> 8 is the physical block
 *     index and slot_idx & 255 is the slot within that block.
 *
 * All strides are applied to __nv_bfloat16 pointers, i.e. they are expressed
 * in bf16 elements. The source row for a head is located at
 * token_idx * src_stride + head_idx * 128, so heads are assumed to be packed
 * contiguously in the source (NOTE(review): confirm the caller guarantees this).
 *
 * @param k_cache_ptr           Destination K cache.
 * @param v_cache_ptr           Destination V cache.
 * @param k_ptr                 Source K tensor.
 * @param v_ptr                 Source V tensor.
 * @param slot_mapping_ptr      Per-token destination slot; a negative value
 *                              means the token is skipped.
 * @param k_src_stride          Per-token stride of the K source.
 * @param v_src_stride          Per-token stride of the V source.
 * @param k_cache_block_stride  Per-physical-block stride of the K cache.
 * @param v_cache_block_stride  Per-physical-block stride of the V cache.
 * @param k_cache_head_stride   Per-head stride of the K cache.
 * @param v_cache_head_stride   Per-head stride of the V cache.
 * @param k_cache_slot_stride   Per-slot stride of the K cache.
 * @param v_cache_slot_stride   Per-slot stride of the V cache.
 */
__device__ __forceinline__ void pagedCachingKernelBf16Head128Block256Vec(
    __nv_bfloat16 *k_cache_ptr,
    __nv_bfloat16 *v_cache_ptr,
    const __nv_bfloat16 *k_ptr,
    const __nv_bfloat16 *v_ptr,
    const int64_t *slot_mapping_ptr,
    const ptrdiff_t k_src_stride,
    const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride,
    const ptrdiff_t v_cache_block_stride,
    const ptrdiff_t k_cache_head_stride,
    const ptrdiff_t v_cache_head_stride,
    const ptrdiff_t k_cache_slot_stride,
    const ptrdiff_t v_cache_slot_stride) {
    constexpr int DH_BF16 = 128;                                       // head dimension
    constexpr int VEC_BF16 = 8;                                        // bf16 elements per 16-byte vector
    constexpr int NUM_VEC = DH_BF16 / VEC_BF16;                        // active threads per row
    constexpr int BLOCK_SHIFT = 8;                                     // log2(256 slots per cache block)
    constexpr int64_t BLOCK_MASK = (int64_t(1) << BLOCK_SHIFT) - 1;    // 255

    const int token_idx = blockIdx.y;
    const int head_idx = blockIdx.x;

    // Negative slot index: nothing to cache for this token.
    const int64_t slot_idx = slot_mapping_ptr[token_idx];
    if (slot_idx < 0) {
        return;
    }

    // Surplus threads (if the block is launched wider than 16) have no work.
    const int tid = threadIdx.x;
    if (tid >= NUM_VEC) {
        return;
    }

    const int64_t physical_block_idx = slot_idx >> BLOCK_SHIFT;
    const int64_t block_offset = slot_idx & BLOCK_MASK;

    // Do the offset arithmetic in ptrdiff_t so it never narrows to 32-bit int.
    const ptrdiff_t src_head_offset = static_cast<ptrdiff_t>(head_idx) * DH_BF16;
    const ptrdiff_t lane_offset = static_cast<ptrdiff_t>(tid) * VEC_BF16;

    const __nv_bfloat16 *k_src = k_ptr + token_idx * k_src_stride + src_head_offset + lane_offset;
    const __nv_bfloat16 *v_src = v_ptr + token_idx * v_src_stride + src_head_offset + lane_offset;

    __nv_bfloat16 *k_dst = k_cache_ptr
                         + physical_block_idx * k_cache_block_stride
                         + head_idx * k_cache_head_stride
                         + block_offset * k_cache_slot_stride
                         + lane_offset;
    __nv_bfloat16 *v_dst = v_cache_ptr
                         + physical_block_idx * v_cache_block_stride
                         + head_idx * v_cache_head_stride
                         + block_offset * v_cache_slot_stride
                         + lane_offset;

    // Each lane is 16 bytes apart, so alignment is uniform across the block
    // and the vector/scalar choice inside the helper does not diverge.
    pagedCachingCopy8Bf16(k_dst, k_src);
    pagedCachingCopy8Bf16(v_dst, v_src);
}
#endif // __CUDACC__
#endif // !ENABLE_MOORE_API && !ENABLE_METAX_API
}
// namespace op::paged_caching::cuda
#endif // __PAGED_CACHING_KERNEL_CUH__
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
View file @
a1937618
...
...
@@ -19,6 +19,25 @@ INFINIOP_CUDA_KERNEL pagedCaching(
k_cache_block_stride
,
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
// BF16 fast path for head_dim=128 / page (block) size=256 / slot_stride=128: 16 threads x uint4 vector copies
#if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
/**
 * Kernel entry point for the specialised BF16 paged-caching path.
 *
 * This is a pure forwarding wrapper: every argument is passed unchanged, in
 * the same order, to
 * op::paged_caching::cuda::pagedCachingKernelBf16Head128Block256Vec, which
 * performs the actual copy of K/V rows into the paged caches.
 */
__global__ void pagedCachingBf16Head128Block256(
    __nv_bfloat16 *k_cache,
    __nv_bfloat16 *v_cache,
    const __nv_bfloat16 *k,
    const __nv_bfloat16 *v,
    const int64_t *slot_mapping,
    ptrdiff_t k_src_stride,
    ptrdiff_t v_src_stride,
    ptrdiff_t k_cache_block_stride,
    ptrdiff_t v_cache_block_stride,
    ptrdiff_t k_cache_head_stride,
    ptrdiff_t v_cache_head_stride,
    ptrdiff_t k_cache_slot_stride,
    ptrdiff_t v_cache_slot_stride) {
    // Local alias keeps the forwarding call readable.
    namespace impl = op::paged_caching::cuda;

    impl::pagedCachingKernelBf16Head128Block256Vec(
        k_cache, v_cache,
        k, v,
        slot_mapping,
        k_src_stride, v_src_stride,
        k_cache_block_stride, v_cache_block_stride,
        k_cache_head_stride, v_cache_head_stride,
        k_cache_slot_stride, v_cache_slot_stride);
}
#endif
namespace
op
::
paged_caching
::
nvidia
{
// PIMPL struct definition
struct
Descriptor
::
Opaque
{
...
...
@@ -94,6 +113,29 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
if
(
dtype
==
INFINI_DTYPE_BF16
)
{
#if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
const
bool
bf16_vec
=
(
head_size
==
128
&&
block_size
==
256
&&
k_cache_slot_stride
==
128
&&
v_cache_slot_stride
==
128
);
if
(
bf16_vec
)
{
constexpr
unsigned
BF16_VEC_THREADS
=
16
;
dim3
block_vec
(
BF16_VEC_THREADS
);
pagedCachingBf16Head128Block256
<<<
grid
,
block_vec
,
shared_mem_size
,
stream
>>>
(
(
__nv_bfloat16
*
)
k_cache
,
(
__nv_bfloat16
*
)
v_cache
,
(
const
__nv_bfloat16
*
)
k
,
(
const
__nv_bfloat16
*
)
v
,
(
const
int64_t
*
)
slot_mapping
,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
#endif
{
pagedCaching
<
__nv_bfloat16
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
(
__nv_bfloat16
*
)
k_cache
,
...
...
@@ -111,6 +153,7 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
}
else
if
(
dtype
==
INFINI_DTYPE_F32
)
{
pagedCaching
<
float
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment