Commit a1937618 authored by yaoht's avatar yaoht
Browse files

add pagedCachingBf16Head128Block256

parent 1b5a38de
...@@ -85,6 +85,67 @@ __device__ void pagedCachingKernel( ...@@ -85,6 +85,67 @@ __device__ void pagedCachingKernel(
} }
} }
#if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
#if defined(__CUDACC__)
#include <vector_types.h>
// Fast-path KV-cache write for BF16 with head_dim == 128 and page size == 256.
// One CUDA block handles one (token, head) pair: blockIdx.y selects the token,
// blockIdx.x selects the head, and each of the first 16 threads copies one
// 16-byte uint4 chunk (8 bf16 values) of K and one of V.
// Preconditions the caller must guarantee (the uint4 casts require them):
//   - source K/V heads are contiguous, exactly 128 bf16 elements apart
//   - all K/V pointers (plus their stride offsets) are 16-byte aligned
// NOTE(review): assumes the launch grid is (num_heads, num_tokens) with at
// least 16 threads per block — confirm at the launch site.
__device__ __forceinline__ void pagedCachingKernelBf16Head128Block256Vec(
    __nv_bfloat16 *k_cache_ptr,
    __nv_bfloat16 *v_cache_ptr,
    const __nv_bfloat16 *k_ptr,
    const __nv_bfloat16 *v_ptr,
    const int64_t *slot_mapping_ptr,
    const ptrdiff_t k_src_stride,
    const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride,
    const ptrdiff_t v_cache_block_stride,
    const ptrdiff_t k_cache_head_stride,
    const ptrdiff_t v_cache_head_stride,
    const ptrdiff_t k_cache_slot_stride,
    const ptrdiff_t v_cache_slot_stride) {
    constexpr int kHeadDim = 128;                          // bf16 elements per head
    constexpr int kElemsPerVec = 8;                        // 8 * sizeof(bf16) == sizeof(uint4)
    constexpr int kVecsPerHead = kHeadDim / kElemsPerVec;  // 16 uint4 chunks per head

    const int token = blockIdx.y;
    const int head = blockIdx.x;

    // A negative slot marks a token that must not be written to the cache.
    const int64_t slot = slot_mapping_ptr[token];
    if (slot < 0) {
        return;
    }

    // page size is fixed at 256, so split the slot with a shift and a mask.
    const int64_t page = slot >> 8;
    const int64_t in_page = slot & int64_t(255);

    // Only the first kVecsPerHead threads participate in the copy.
    const int vec = threadIdx.x;
    if (vec >= kVecsPerHead) {
        return;
    }
    const int elem = vec * kElemsPerVec;

    // Source heads are contiguous: head stride is exactly kHeadDim elements.
    const __nv_bfloat16 *src_k = k_ptr + token * k_src_stride + head * kHeadDim + elem;
    const __nv_bfloat16 *src_v = v_ptr + token * v_src_stride + head * kHeadDim + elem;
    __nv_bfloat16 *dst_k = k_cache_ptr + page * k_cache_block_stride
        + head * k_cache_head_stride + in_page * k_cache_slot_stride + elem;
    __nv_bfloat16 *dst_v = v_cache_ptr + page * v_cache_block_stride
        + head * v_cache_head_stride + in_page * v_cache_slot_stride + elem;

    // One vectorized 16-byte load/store each for K and for V.
    *reinterpret_cast<uint4 *>(dst_k) = *reinterpret_cast<const uint4 *>(src_k);
    *reinterpret_cast<uint4 *>(dst_v) = *reinterpret_cast<const uint4 *>(src_v);
}
#endif // __CUDACC__
#endif // !ENABLE_MOORE_API && !ENABLE_METAX_API
} // namespace op::paged_caching::cuda } // namespace op::paged_caching::cuda
#endif // __PAGED_CACHING_KERNEL_CUH__ #endif // __PAGED_CACHING_KERNEL_CUH__
...@@ -19,6 +19,25 @@ INFINIOP_CUDA_KERNEL pagedCaching( ...@@ -19,6 +19,25 @@ INFINIOP_CUDA_KERNEL pagedCaching(
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride); k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
// BF16 fast path (head_dim=128, page size=256, slot stride=128): 16 threads, each copying one uint4 vector.
#if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
// __global__ entry point for the specialized BF16 head_dim=128 / page=256
// caching path; it forwards every argument unchanged to the vectorized
// device helper pagedCachingKernelBf16Head128Block256Vec.
// Launch geometry (from the helper's indexing): blockIdx.x = head index,
// blockIdx.y = token index; only threadIdx.x < 16 do work, each moving one
// 16-byte uint4 chunk of K and one of V.
// NOTE(review): the helper's uint4 accesses need 16-byte-aligned K/V
// pointers — verify the allocator guarantees this.
__global__ void pagedCachingBf16Head128Block256(
__nv_bfloat16 *k_cache, __nv_bfloat16 *v_cache,
const __nv_bfloat16 *k, const __nv_bfloat16 *v,
const int64_t *slot_mapping,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernelBf16Head128Block256Vec(
k_cache, v_cache, k, v, slot_mapping,
k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride,
k_cache_head_stride, v_cache_head_stride,
k_cache_slot_stride, v_cache_slot_stride);
}
#endif
namespace op::paged_caching::nvidia { namespace op::paged_caching::nvidia {
// PIMPL struct definition // PIMPL struct definition
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -94,15 +113,18 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -94,15 +113,18 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_cache_slot_stride, k_cache_slot_stride,
v_cache_slot_stride); v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS> #if !defined(ENABLE_MOORE_API) && !defined(ENABLE_METAX_API)
<<<grid, block, shared_mem_size, stream>>>( const bool bf16_vec = (head_size == 128 && block_size == 256 && k_cache_slot_stride == 128
&& v_cache_slot_stride == 128);
if (bf16_vec) {
constexpr unsigned BF16_VEC_THREADS = 16;
dim3 block_vec(BF16_VEC_THREADS);
pagedCachingBf16Head128Block256<<<grid, block_vec, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache, (__nv_bfloat16 *)k_cache,
(__nv_bfloat16 *)v_cache, (__nv_bfloat16 *)v_cache,
(const __nv_bfloat16 *)k, (const __nv_bfloat16 *)k,
(const __nv_bfloat16 *)v, (const __nv_bfloat16 *)v,
(const int64_t *)slot_mapping, (const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
...@@ -111,6 +133,27 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,6 +133,27 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
v_cache_head_stride, v_cache_head_stride,
k_cache_slot_stride, k_cache_slot_stride,
v_cache_slot_stride); v_cache_slot_stride);
} else
#endif
{
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache,
(__nv_bfloat16 *)v_cache,
(const __nv_bfloat16 *)k,
(const __nv_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
}
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment