Commit 4ec64732 authored by zhuwenwen
Browse files

add indexer_k_cache_kernel

parent 25ec6a34
......@@ -83,6 +83,13 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Copies un-quantized indexer K vectors into the paged kv_cache: token i of
// `k` is written to the slot given by slot_mapping[i]; tokens whose slot is
// negative (padding) are skipped.
// NOTE(review): `scale_fmt` is not used by the implementation in this
// change; the op schema exposes this argument under the name
// `kv_cache_dtype` — confirm which name/meaning is intended.
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt);
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
......
......@@ -22,6 +22,7 @@
#include <cfloat>
#include <map>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
......@@ -808,6 +809,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
}
}
// Copies un-quantized indexer K vectors into the paged KV cache, converting
// each element from scalar_t to cache_t through float.
//
// Launch layout:
//   - blockIdx.x selects the token (one grid row per token).
//   - The flattened (blockIdx.y, threadIdx.y, threadIdx.x) index selects a
//     group of VEC_SIZE consecutive elements along head_dim; threads whose
//     group starts at or beyond head_dim exit immediately.
//
// Loads use a single 8-byte vector access when VEC_SIZE elements occupy
// exactly 8 bytes (2-byte scalar_t), the group is complete, and the source
// address is 8-byte aligned. Every other case (wider scalar_t, a partial tail
// group when head_dim is not a multiple of VEC_SIZE, or an unaligned row)
// falls back to element-wise loads so no thread reads or writes outside its
// token's row.
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int cache_block_size,                // cache block size
    const int cache_stride  // stride for each token in kv_cache
) {
  constexpr int VEC_SIZE = 4;
  const int64_t token_idx = blockIdx.x;
  const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
                                threadIdx.y * blockDim.x + threadIdx.x) *
                               VEC_SIZE;
  if (head_dim_idx >= head_dim) {
    return;
  }

  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  // Number of elements this thread owns; smaller than VEC_SIZE only for the
  // last group of a row whose head_dim is not a multiple of VEC_SIZE.
  const int64_t remaining = head_dim - head_dim_idx;
  const int num_valid =
      remaining < VEC_SIZE ? static_cast<int>(remaining) : VEC_SIZE;

  const scalar_t* src = k + token_idx * head_dim + head_dim_idx;
  alignas(sizeof(float2)) scalar_t vals[VEC_SIZE];
  bool vector_loaded = false;
  if constexpr (sizeof(scalar_t) * VEC_SIZE == sizeof(float2)) {
    // float2 is only used as an 8-byte container here; it is valid solely
    // when it holds exactly VEC_SIZE elements and the address is aligned.
    if (num_valid == VEC_SIZE &&
        reinterpret_cast<uintptr_t>(src) % sizeof(float2) == 0) {
      *reinterpret_cast<float2*>(vals) =
          *reinterpret_cast<const float2*>(src);
      vector_loaded = true;
    }
  }
  if (!vector_loaded) {
    for (int i = 0; i < num_valid; i++) {
      vals[i] = src[i];
    }
  }

  // NOTE(review): the per-token offset inside a cache block uses head_dim,
  // not cache_stride. The two agree when cache_stride == head_dim; confirm
  // the intended layout if they can differ.
  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;
  for (int i = 0; i < num_valid; i++) {
    float val = static_cast<float>(vals[i]);
    if constexpr (std::is_same<cache_t, at::Half>::value ||
                  std::is_same<cache_t, __half>::value) {
      kv_cache[dst_offset + i] = __float2half(val);
    } else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
                         std::is_same<cache_t, __nv_bfloat16>::value) {
      kv_cache[dst_offset + i] = __float2bfloat16(val);
    } else if constexpr (std::is_same<cache_t, float>::value) {
      kv_cache[dst_offset + i] = val;
    } else {
      kv_cache[dst_offset + i] = static_cast<cache_t>(val);
    }
  }
}
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,
......@@ -1791,6 +1838,78 @@ void indexer_k_quant_and_cache(
CALL_INDEXER_K_QUANT_AND_CACHE);
}
// Launches vllm::indexer_k_cache_kernel for one (K dtype, cache dtype) pair.
// The expansion site must provide: grid, block, stream, k, kv_cache,
// slot_mapping, head_dim, cache_block_size and cache_stride.
#define CALL_INDEXER_K_CACHE(KV_T, CACHE_T)                                  \
  vllm::indexer_k_cache_kernel<KV_T, CACHE_T><<<grid, block, 0, stream>>>(   \
      static_cast<KV_T*>(k.data_ptr()),                                      \
      static_cast<CACHE_T*>(kv_cache.data_ptr()),                            \
      slot_mapping.data_ptr<int64_t>(),                                      \
      head_dim,                                                              \
      cache_block_size,                                                      \
      cache_stride);
// Host launcher: copies un-quantized indexer K vectors into the paged
// kv_cache at the slots listed in slot_mapping.
//
// Args:
//   k:            [num_tokens, head_dim], contiguous, Half or BFloat16.
//   kv_cache:     [num_blocks, block_size, cache_stride]; Float, Half or
//                 BFloat16. Written in place.
//   slot_mapping: [num_tokens], int64; negative entries are skipped by the
//                 kernel.
//   scale_fmt:    accepted for signature compatibility; not used here.
//
// Raises (via TORCH_CHECK) when tensors live on different devices, when k
// violates the preconditions of the kernel's 8-byte vector load, or when the
// kv_cache dtype is unsupported.
//
// Launch geometry: block = (32, 4) threads, each thread handling 4
// consecutive elements, so one block covers 512 elements of head_dim;
// grid = (num_tokens, ceil(head_dim / 512)).
void indexer_k_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& scale_fmt) {
  (void)scale_fmt;  // no scales are produced on the un-quantized path

  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");

  const int num_tokens = k.size(0);
  const int head_dim = k.size(1);
  const int cache_block_size = kv_cache.size(1);
  const int cache_stride = kv_cache.size(2);

  constexpr int vec_size = 4;
  constexpr int threads_x = 32;
  // Elements of head_dim covered by one thread block.
  constexpr int elems_per_block = threads_x * vec_size * vec_size;

  // The kernel packs vec_size K elements into one 8-byte load and indexes k
  // as a dense [num_tokens, head_dim] array; reject inputs that would make
  // that load read the wrong bytes.
  TORCH_CHECK(k.scalar_type() == at::ScalarType::Half ||
                  k.scalar_type() == at::ScalarType::BFloat16,
              "indexer_k_cache: k must be Half or BFloat16, got ", k.dtype());
  TORCH_CHECK(head_dim % vec_size == 0,
              "indexer_k_cache: head_dim must be a multiple of ", vec_size,
              ", got ", head_dim);
  TORCH_CHECK(k.is_contiguous(), "indexer_k_cache: k must be contiguous");

  // A zero-sized grid is an invalid launch configuration; nothing to copy.
  if (num_tokens == 0 || head_dim == 0) {
    return;
  }

  dim3 grid(num_tokens, (head_dim + elems_per_block - 1) / elems_per_block);
  dim3 block(threads_x, vec_size);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      k.scalar_type(), "indexer_k_cache", ([&] {
        using k_t = scalar_t;
        auto kv_cache_type = kv_cache.scalar_type();
        if (kv_cache_type == at::ScalarType::Float) {
          vllm::indexer_k_cache_kernel<k_t, float>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(),
                  kv_cache.data_ptr<float>(),
                  slot_mapping.data_ptr<int64_t>(),
                  head_dim,
                  cache_block_size,
                  cache_stride);
        } else if (kv_cache_type == at::ScalarType::Half) {
          vllm::indexer_k_cache_kernel<k_t, at::Half>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(),
                  kv_cache.data_ptr<at::Half>(),
                  slot_mapping.data_ptr<int64_t>(),
                  head_dim,
                  cache_block_size,
                  cache_stride);
        } else if (kv_cache_type == at::ScalarType::BFloat16) {
          vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(),
                  kv_cache.data_ptr<at::BFloat16>(),
                  slot_mapping.data_ptr<int64_t>(),
                  head_dim,
                  cache_block_size,
                  cache_stride);
        } else {
          TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
        }
      }));
}
// Macro to dispatch the kernel based on the data amount.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
......
......@@ -822,6 +822,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
// Registers the un-quantized indexer K cache write op and binds its CUDA
// implementation. kv_cache is declared mutable (Tensor!) in the schema.
// NOTE(review): the schema names the string argument `kv_cache_dtype`, while
// the C++ function calls the same parameter `scale_fmt` — confirm the
// intended name.
cache_ops.def(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_cache", torch::kCUDA,
&indexer_k_cache);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
......
......@@ -2615,6 +2615,13 @@ def indexer_k_quant_and_cache(
torch.ops._C_cache_ops.indexer_k_quant_and_cache(
k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype
)
def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,
                    slot_mapping: torch.Tensor,
                    kv_cache_dtype: str) -> None:
    """Forward to the ``_C_cache_ops.indexer_k_cache`` custom op.

    All four arguments are passed through positionally and unchanged; the op
    schema marks ``kv_cache`` as mutated in place and returns nothing.

    Args:
        k: Tensor handed to the op as its first argument.
        kv_cache: Cache tensor the op writes into.
        slot_mapping: Tensor handed to the op as its third argument.
        kv_cache_dtype: String forwarded as the op's last argument.
    """
    cache_write_op = torch.ops._C_cache_ops.indexer_k_cache
    cache_write_op(k, kv_cache, slot_mapping, kv_cache_dtype)
def cp_gather_indexer_k_quant_cache(
......
......@@ -682,6 +682,13 @@ def sparse_attn_indexer(
quant_block_size,
scale_fmt,
)
else:
ops.indexer_k_cache(
k,
kv_cache,
slot_mapping,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment