"cmake/vscode:/vscode.git/clone" did not exist on "0259837d0c458ba21e85d62a855ec0cfaa3a1924"
Commit 82155c76 authored by zhuwenwen
Browse files

skip cp_gather_and_upconvert_fp8_kv_cache

parent 13baa653
...@@ -58,13 +58,13 @@ void cp_gather_cache( ...@@ -58,13 +58,13 @@ void cp_gather_cache(
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt); int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace // Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache( // void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] // torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576] // torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] // torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH] // torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH] // torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size); // int64_t batch_size);
// Indexer K quantization and cache function // Indexer K quantization and cache function
void indexer_k_quant_and_cache( void indexer_k_quant_and_cache(
......
...@@ -1007,70 +1007,70 @@ namespace vllm { ...@@ -1007,70 +1007,70 @@ namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace // Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion // Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache( // __global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] // const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576] // __nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES] // const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
const int32_t* __restrict__ workspace_starts, // [num_reqs] // const int32_t* __restrict__ workspace_starts, // [num_reqs]
const int32_t num_reqs, const int32_t block_size, // const int32_t num_reqs, const int32_t block_size,
const int32_t total_tokens, const int64_t block_table_stride, // const int32_t total_tokens, const int64_t block_table_stride,
const int64_t cache_block_stride, const int64_t cache_entry_stride, // const int64_t cache_block_stride, const int64_t cache_entry_stride,
const int64_t dst_entry_stride) { // const int64_t dst_entry_stride) {
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; // const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= total_tokens) return; // if (flat_warp_id >= total_tokens) return;
const int lane_id = threadIdx.x & 31; // const int lane_id = threadIdx.x & 31;
// Binary search to find which request owns this output token // // Binary search to find which request owns this output token
int lo = 0, hi = num_reqs - 1; // int lo = 0, hi = num_reqs - 1;
while (lo < hi) { // while (lo < hi) {
int mid = (lo + hi + 1) >> 1; // int mid = (lo + hi + 1) >> 1;
if (workspace_starts[mid] <= flat_warp_id) // if (workspace_starts[mid] <= flat_warp_id)
lo = mid; // lo = mid;
else // else
hi = mid - 1; // hi = mid - 1;
} // }
const int req_id = lo; // const int req_id = lo;
// Compute physical token address via block table // // Compute physical token address via block table
const int out_token_id = flat_warp_id; // const int out_token_id = flat_warp_id;
const int token_offset = out_token_id - workspace_starts[req_id]; // const int token_offset = out_token_id - workspace_starts[req_id];
const int cache_block_idx = token_offset / block_size; // const int cache_block_idx = token_offset / block_size;
const int offset_in_block = token_offset % block_size; // const int offset_in_block = token_offset % block_size;
const int physical_block = // const int physical_block =
block_table[req_id * block_table_stride + cache_block_idx]; // block_table[req_id * block_table_stride + cache_block_idx];
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride + // const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
offset_in_block * cache_entry_stride; // offset_in_block * cache_entry_stride;
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr); // const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
const int4 fp8_data = nope_src[lane_id]; // const int4 fp8_data = nope_src[lane_id];
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512); // const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const float scale = scales_ptr[lane_id >> 3]; // const float scale = scales_ptr[lane_id >> 3];
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y); // const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w); // const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM // #ifdef USE_ROCM
const bf16_8_t bf16_lo = // const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale); // fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
const bf16_8_t bf16_hi = // const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale); // fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else // #else
const bf16_8_t bf16_lo = // const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3); // fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
const bf16_8_t bf16_hi = // const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3); // fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif // #endif
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride; // __nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2; // int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo); // nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi); // nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528); // const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512); // int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
rope_dst[lane_id] = rope_src[lane_id]; // rope_dst[lane_id] = rope_src[lane_id];
} // }
template <typename scalar_t> template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
...@@ -1213,69 +1213,69 @@ void cp_gather_cache( ...@@ -1213,69 +1213,69 @@ void cp_gather_cache(
} }
} }
void cp_gather_and_upconvert_fp8_kv_cache( // void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] // torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576] // torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] // torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH] // torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH] // torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) { // int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); // at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1); // int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1); // int32_t head_dim = dst.size(1);
TORCH_CHECK(block_table.dtype() == torch::kInt32, // TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32"); // "block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); // TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, // TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32"); // "workspace_starts must be int32");
TORCH_CHECK(src_cache.device() == dst.device(), // TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device"); // "src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(), // TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device"); // "src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(), // TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device"); // "src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(), // TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device"); // "src_cache and workspace_starts must be on the same device");
auto dtype = src_cache.scalar_type(); // auto dtype = src_cache.scalar_type();
TORCH_CHECK( // TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8 // dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3 // dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2 // dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ", // "src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype()); // src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); // TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); // TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
int64_t block_table_stride = block_table.stride(0); // int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0); // int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1); // int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0); // int64_t dst_entry_stride = dst.stride(0);
const uint8_t* src_ptr = nullptr; // const uint8_t* src_ptr = nullptr;
if (dtype == at::ScalarType::Byte) { // if (dtype == at::ScalarType::Byte) {
src_ptr = src_cache.data_ptr<uint8_t>(); // src_ptr = src_cache.data_ptr<uint8_t>();
} else { // } else {
// float8_e4m3fn or float8_e5m2 // // float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr()); // src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
} // }
const int total_tokens = dst.size(0); // const int total_tokens = dst.size(0);
constexpr int warps_per_block = 8; // constexpr int warps_per_block = 8;
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block; // const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
const int block_size_threads = warps_per_block * 32; // 256 threads // const int block_size_threads = warps_per_block * 32; // 256 threads
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0, // vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
stream>>>( // stream>>>(
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), // src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(), // block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
static_cast<int32_t>(batch_size), block_size, total_tokens, // static_cast<int32_t>(batch_size), block_size, total_tokens,
block_table_stride, cache_block_stride, cache_entry_stride, // block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride); // dst_entry_stride);
} // }
// Macro to dispatch the kernel based on the data type. // Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <cassert> #include <cassert>
#ifdef USE_ROCM #ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#else #else
#include <cuda_bf16.h> #include <cuda_bf16.h>
......
...@@ -799,12 +799,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -799,12 +799,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def( // cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, " // "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " // "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()"); // "batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, // cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache); // &cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def( cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
......
...@@ -2713,27 +2713,27 @@ def cp_gather_cache( ...@@ -2713,27 +2713,27 @@ def cp_gather_cache(
) )
def cp_gather_and_upconvert_fp8_kv_cache( # def cp_gather_and_upconvert_fp8_kv_cache(
src_cache: torch.Tensor, # src_cache: torch.Tensor,
dst: torch.Tensor, # dst: torch.Tensor,
block_table: torch.Tensor, # block_table: torch.Tensor,
seq_lens: torch.Tensor, # seq_lens: torch.Tensor,
workspace_starts: torch.Tensor, # workspace_starts: torch.Tensor,
batch_size: int, # batch_size: int,
) -> None: # ) -> None:
"""Gather and upconvert FP8 KV cache to BF16 workspace. # """Gather and upconvert FP8 KV cache to BF16 workspace.
Args: # Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656] # src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576] # dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks] # block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs] # seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs] # workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests # batch_size: Number of requests
""" # """
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( # torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size # src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
) # )
def concat_mla_q( def concat_mla_q(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment