Commit 82155c76 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip cp_gather_and_upconvert_fp8_kv_cache

parent 13baa653
......@@ -58,13 +58,13 @@ void cp_gather_cache(
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace.
//
// Reads rows of `src_cache` selected through `block_table` and writes one
// BF16 row per token into `dst`. `workspace_starts` holds, per request, the
// index of that request's first row in `dst`. The trailing comments give the
// tensor shapes this entry point is documented to take.
// NOTE(review): `seq_lens` is part of the signature; how (or whether) the
// implementation consumes it is not visible from this declaration.
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// void cp_gather_and_upconvert_fp8_kv_cache(
// torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
// torch::Tensor const& dst, // [TOT_TOKENS, 576]
// torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
// torch::Tensor const& seq_lens, // [BATCH]
// torch::Tensor const& workspace_starts, // [BATCH]
// int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
......
......@@ -1007,70 +1007,70 @@ namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
//
// Work mapping: one 32-thread warp handles one output row (token). The flat
// warp index is the global thread index divided by 32, and warps whose index
// is >= total_tokens exit immediately.
//
// Source row layout as read by this kernel (byte offsets from token_ptr):
//   [0, 512)    32 x int4 loads, 16 FP8 bytes per lane
//   [512, 528)  floats indexed by (lane_id >> 3), i.e. one per group of 8
//               lanes (128 FP8 bytes), used as the conversion scale
//   [528, 656)  32 x 32-bit words copied verbatim (one per lane)
// Destination row layout (in __nv_bfloat16 elements from dst_ptr):
//   [0, 512)    converted values, 16 per lane (two int4 stores)
//   [512, 576)  the 32 copied 32-bit words
//
// Parameters:
//   src_cache           base pointer of the FP8 cache, addressed in bytes
//   dst                 base pointer of the BF16 output
//   block_table         per-request table of physical block indices
//   workspace_starts    per-request first output row; the binary search
//                       below requires it to be non-decreasing
//   num_reqs            number of entries searched in workspace_starts
//   block_size          tokens per cache block
//   total_tokens        number of output rows to produce
//   block_table_stride  element stride between block_table rows
//   cache_block_stride  byte stride between cache blocks
//   cache_entry_stride  byte stride between tokens inside a block
//   dst_entry_stride    element stride between dst rows
// NOTE(review): the int4 loads/stores assume token_ptr and dst_ptr are
// 16-byte aligned; that depends on the strides passed in -- confirm at the
// call site.
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
const int32_t* __restrict__ workspace_starts, // [num_reqs]
const int32_t num_reqs, const int32_t block_size,
const int32_t total_tokens, const int64_t block_table_stride,
const int64_t cache_block_stride, const int64_t cache_entry_stride,
const int64_t dst_entry_stride) {
// Global thread index >> 5 gives the warp index, which is the output row.
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= total_tokens) return;
const int lane_id = threadIdx.x & 31;
// Binary search to find which request owns this output token
// Upper-biased midpoint: converges on the largest index whose start is
// <= flat_warp_id without looping forever when hi == lo + 1.
int lo = 0, hi = num_reqs - 1;
while (lo < hi) {
int mid = (lo + hi + 1) >> 1;
if (workspace_starts[mid] <= flat_warp_id)
lo = mid;
else
hi = mid - 1;
}
const int req_id = lo;
// Compute physical token address via block table
const int out_token_id = flat_warp_id;
// Position of this token inside its request's span of output rows.
const int token_offset = out_token_id - workspace_starts[req_id];
const int cache_block_idx = token_offset / block_size;
const int offset_in_block = token_offset % block_size;
const int physical_block =
block_table[req_id * block_table_stride + cache_block_idx];
// The int64 strides promote these products to 64-bit before the pointer add.
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
offset_in_block * cache_entry_stride;
// Each lane loads 16 consecutive FP8 bytes from the first 512 bytes.
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
const int4 fp8_data = nope_src[lane_id];
// Scales start at byte 512; lanes 8k..8k+7 share scale k.
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const float scale = scales_ptr[lane_id >> 3];
// Split the 16 loaded bytes into two 8-byte halves for the vector converter.
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
// The ROCm overload takes no format argument; the CUDA overload is told
// explicitly to interpret the bytes as E4M3.
#ifdef USE_ROCM
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
// Two int4 stores per lane, so each lane owns 32 output bytes.
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
// Tail section: copied bit-for-bit, one 32-bit word per lane, from byte 528
// of the source row to element 512 of the destination row.
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
rope_dst[lane_id] = rope_src[lane_id];
}
// __global__ void cp_gather_and_upconvert_fp8_kv_cache(
// const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
// __nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
// const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
// const int32_t* __restrict__ workspace_starts, // [num_reqs]
// const int32_t num_reqs, const int32_t block_size,
// const int32_t total_tokens, const int64_t block_table_stride,
// const int64_t cache_block_stride, const int64_t cache_entry_stride,
// const int64_t dst_entry_stride) {
// const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
// if (flat_warp_id >= total_tokens) return;
// const int lane_id = threadIdx.x & 31;
// // Binary search to find which request owns this output token
// int lo = 0, hi = num_reqs - 1;
// while (lo < hi) {
// int mid = (lo + hi + 1) >> 1;
// if (workspace_starts[mid] <= flat_warp_id)
// lo = mid;
// else
// hi = mid - 1;
// }
// const int req_id = lo;
// // Compute physical token address via block table
// const int out_token_id = flat_warp_id;
// const int token_offset = out_token_id - workspace_starts[req_id];
// const int cache_block_idx = token_offset / block_size;
// const int offset_in_block = token_offset % block_size;
// const int physical_block =
// block_table[req_id * block_table_stride + cache_block_idx];
// const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
// offset_in_block * cache_entry_stride;
// const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
// const int4 fp8_data = nope_src[lane_id];
// const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
// const float scale = scales_ptr[lane_id >> 3];
// const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
// const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
// #ifdef USE_ROCM
// const bf16_8_t bf16_lo =
// fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
// const bf16_8_t bf16_hi =
// fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
// #else
// const bf16_8_t bf16_lo =
// fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
// const bf16_8_t bf16_hi =
// fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
// #endif
// __nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
// int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
// nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
// nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
// const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
// int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
// rope_dst[lane_id] = rope_src[lane_id];
// }
template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
......@@ -1213,69 +1213,69 @@ void cp_gather_cache(
}
}
// Host-side launcher for vllm::cp_gather_and_upconvert_fp8_kv_cache.
//
// Validates dtypes and device placement, derives the strides the kernel
// needs, and launches one warp per row of `dst` on the current CUDA stream of
// `src_cache`'s device.
//
// Args:
//   src_cache         FP8 cache; dtype must be uint8, float8_e4m3fn or
//                     float8_e5m2. It is handed to the kernel as raw bytes.
//   dst               bfloat16 output written in place; dst.size(1) must be
//                     576. dst.size(0) is the number of rows produced.
//   block_table       int32 per-request physical block indices.
//   seq_lens          int32; only its dtype and device are validated here,
//                     it is not forwarded to the kernel.
//   workspace_starts  int32 per-request first output row.
//   batch_size        number of requests; forwarded as the kernel's num_reqs.
//
// Raises (via TORCH_CHECK) when any dtype, device or head_dim requirement
// above is violated.
void cp_gather_and_upconvert_fp8_kv_cache(
    torch::Tensor const& src_cache,         // [NUM_BLOCKS, BLOCK_SIZE, 656]
    torch::Tensor const& dst,               // [TOT_TOKENS, 576]
    torch::Tensor const& block_table,       // [BATCH, BLOCK_INDICES]
    torch::Tensor const& seq_lens,          // [BATCH]
    torch::Tensor const& workspace_starts,  // [BATCH]
    int64_t batch_size) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t block_size = static_cast<int32_t>(src_cache.size(1));
  const int32_t head_dim = static_cast<int32_t>(dst.size(1));

  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
  TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
              "workspace_starts must be int32");

  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == seq_lens.device(),
              "src_cache and seq_lens must be on the same device");
  TORCH_CHECK(src_cache.device() == workspace_starts.device(),
              "src_cache and workspace_starts must be on the same device");

  const auto dtype = src_cache.scalar_type();
  TORCH_CHECK(
      dtype == at::ScalarType::Byte ||           // uint8
          dtype == at::ScalarType::Float8_e4m3fn ||  // fp8 e4m3
          dtype == at::ScalarType::Float8_e5m2,      // fp8 e5m2
      "src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
      src_cache.dtype());
  TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
  TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");

  const int total_tokens = static_cast<int>(dst.size(0));
  // Nothing to gather. Launching with a zero-sized grid is an invalid
  // execution configuration and would leave a CUDA error pending for whatever
  // call checks the error state next, so bail out after validation instead.
  if (total_tokens == 0) {
    return;
  }

  const int64_t block_table_stride = block_table.stride(0);
  const int64_t cache_block_stride = src_cache.stride(0);
  const int64_t cache_entry_stride = src_cache.stride(1);
  const int64_t dst_entry_stride = dst.stride(0);

  // All accepted dtypes are one byte wide, so the kernel addresses the cache
  // as raw bytes regardless of which of them the tensor carries.
  const uint8_t* src_ptr = static_cast<const uint8_t*>(src_cache.data_ptr());

  // One warp per output row, eight warps (256 threads) per thread block.
  constexpr int warps_per_block = 8;
  const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
  const int block_size_threads = warps_per_block * 32;  // 256 threads

  vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
                                               stream>>>(
      src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
      block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
      static_cast<int32_t>(batch_size), block_size, total_tokens,
      block_table_stride, cache_block_stride, cache_entry_stride,
      dst_entry_stride);
}
// void cp_gather_and_upconvert_fp8_kv_cache(
// torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
// torch::Tensor const& dst, // [TOT_TOKENS, 576]
// torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
// torch::Tensor const& seq_lens, // [BATCH]
// torch::Tensor const& workspace_starts, // [BATCH]
// int64_t batch_size) {
// at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// int32_t block_size = src_cache.size(1);
// int32_t head_dim = dst.size(1);
// TORCH_CHECK(block_table.dtype() == torch::kInt32,
// "block_table must be int32");
// TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
// TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
// "workspace_starts must be int32");
// TORCH_CHECK(src_cache.device() == dst.device(),
// "src_cache and dst must be on the same device");
// TORCH_CHECK(src_cache.device() == block_table.device(),
// "src_cache and block_table must be on the same device");
// TORCH_CHECK(src_cache.device() == seq_lens.device(),
// "src_cache and seq_lens must be on the same device");
// TORCH_CHECK(src_cache.device() == workspace_starts.device(),
// "src_cache and workspace_starts must be on the same device");
// auto dtype = src_cache.scalar_type();
// TORCH_CHECK(
// dtype == at::ScalarType::Byte || // uint8
// dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
// dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
// "src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
// src_cache.dtype());
// TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
// TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
// int64_t block_table_stride = block_table.stride(0);
// int64_t cache_block_stride = src_cache.stride(0);
// int64_t cache_entry_stride = src_cache.stride(1);
// int64_t dst_entry_stride = dst.stride(0);
// const uint8_t* src_ptr = nullptr;
// if (dtype == at::ScalarType::Byte) {
// src_ptr = src_cache.data_ptr<uint8_t>();
// } else {
// // float8_e4m3fn or float8_e5m2
// src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
// }
// const int total_tokens = dst.size(0);
// constexpr int warps_per_block = 8;
// const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
// const int block_size_threads = warps_per_block * 32; // 256 threads
// vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
// stream>>>(
// src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
// block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
// static_cast<int32_t>(batch_size), block_size, total_tokens,
// block_table_stride, cache_block_stride, cache_entry_stride,
// dst_entry_stride);
// }
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
......
......@@ -8,6 +8,8 @@
#include <cassert>
#ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
......
......@@ -799,12 +799,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
// Gather and upconvert FP8 KV cache to BF16 workspace.
// Schema: six positional arguments, no return value. `Tensor! dst` is the
// only argument annotated with `!`.
// NOTE(review): `!` is presumably the in-place-mutation marker for `dst` --
// confirm against the schema parser in use.
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
// Route the CUDA dispatch key to the C++ function of the same name.
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
// cache_ops.def(
// "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
// "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
// "batch_size) -> ()");
// cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
// &cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
......
......@@ -2713,27 +2713,27 @@ def cp_gather_cache(
)
def cp_gather_and_upconvert_fp8_kv_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    workspace_starts: torch.Tensor,
    batch_size: int,
) -> None:
    """Fill a BF16 workspace with tokens gathered from an FP8 KV cache.

    Thin wrapper that forwards every argument, unchanged and in order, to
    the ``_C_cache_ops`` custom op of the same name. The result is written
    into ``dst``; nothing is returned.

    Args:
        src_cache: FP8 KV cache [num_blocks, block_size, 656]
        dst: BF16 output workspace [total_tokens, 576]
        block_table: Block indices [num_reqs, max_blocks]
        seq_lens: Sequence lengths [num_reqs]
        workspace_starts: Workspace start offsets [num_reqs]
        batch_size: Number of requests
    """
    gather_op = torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache
    gather_op(src_cache, dst, block_table, seq_lens, workspace_starts, batch_size)
# def cp_gather_and_upconvert_fp8_kv_cache(
# src_cache: torch.Tensor,
# dst: torch.Tensor,
# block_table: torch.Tensor,
# seq_lens: torch.Tensor,
# workspace_starts: torch.Tensor,
# batch_size: int,
# ) -> None:
# """Gather and upconvert FP8 KV cache to BF16 workspace.
# Args:
# src_cache: FP8 KV cache [num_blocks, block_size, 656]
# dst: BF16 output workspace [total_tokens, 576]
# block_table: Block indices [num_reqs, max_blocks]
# seq_lens: Sequence lengths [num_reqs]
# workspace_starts: Workspace start offsets [num_reqs]
# batch_size: Number of requests
# """
# torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
# src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
# )
def concat_mla_q(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment