Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82155c76
Commit
82155c76
authored
Mar 24, 2026
by
zhuwenwen
Browse files
skip cp_gather_and_upconvert_fp8_kv_cache
parent
13baa653
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
163 additions
and
161 deletions
+163
-161
csrc/cache.h
csrc/cache.h
+7
-7
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+127
-127
csrc/cuda_vec_utils.cuh
csrc/cuda_vec_utils.cuh
+2
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+21
-21
No files found.
csrc/cache.h
View file @
82155c76
...
...
@@ -58,13 +58,13 @@ void cp_gather_cache(
int64_t
batch_size
,
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
);
// Gather and upconvert FP8 KV cache to BF16 workspace
void
cp_gather_and_upconvert_fp8_kv_cache
(
torch
::
Tensor
const
&
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE, 656]
torch
::
Tensor
const
&
dst
,
// [TOT_TOKENS, 576]
torch
::
Tensor
const
&
block_table
,
// [BATCH, BLOCK_INDICES]
torch
::
Tensor
const
&
seq_lens
,
// [BATCH]
torch
::
Tensor
const
&
workspace_starts
,
// [BATCH]
int64_t
batch_size
);
//
void cp_gather_and_upconvert_fp8_kv_cache(
//
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
//
torch::Tensor const& dst, // [TOT_TOKENS, 576]
//
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
//
torch::Tensor const& seq_lens, // [BATCH]
//
torch::Tensor const& workspace_starts, // [BATCH]
//
int64_t batch_size);
// Indexer K quantization and cache function
void
indexer_k_quant_and_cache
(
...
...
csrc/cache_kernels.cu
View file @
82155c76
...
...
@@ -1007,70 +1007,70 @@ namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__
void
cp_gather_and_upconvert_fp8_kv_cache
(
const
uint8_t
*
__restrict__
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16
*
__restrict__
dst
,
// [total_tokens, 576]
const
int32_t
*
__restrict__
block_table
,
// [num_reqs, BLOCK_INDICES]
const
int32_t
*
__restrict__
workspace_starts
,
// [num_reqs]
const
int32_t
num_reqs
,
const
int32_t
block_size
,
const
int32_t
total_tokens
,
const
int64_t
block_table_stride
,
const
int64_t
cache_block_stride
,
const
int64_t
cache_entry_stride
,
const
int64_t
dst_entry_stride
)
{
const
int
flat_warp_id
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
)
>>
5
;
if
(
flat_warp_id
>=
total_tokens
)
return
;
const
int
lane_id
=
threadIdx
.
x
&
31
;
// Binary search to find which request owns this output token
int
lo
=
0
,
hi
=
num_reqs
-
1
;
while
(
lo
<
hi
)
{
int
mid
=
(
lo
+
hi
+
1
)
>>
1
;
if
(
workspace_starts
[
mid
]
<=
flat_warp_id
)
lo
=
mid
;
else
hi
=
mid
-
1
;
}
const
int
req_id
=
lo
;
// Compute physical token address via block table
const
int
out_token_id
=
flat_warp_id
;
const
int
token_offset
=
out_token_id
-
workspace_starts
[
req_id
];
const
int
cache_block_idx
=
token_offset
/
block_size
;
const
int
offset_in_block
=
token_offset
%
block_size
;
const
int
physical_block
=
block_table
[
req_id
*
block_table_stride
+
cache_block_idx
];
const
uint8_t
*
token_ptr
=
src_cache
+
physical_block
*
cache_block_stride
+
offset_in_block
*
cache_entry_stride
;
const
int4
*
nope_src
=
reinterpret_cast
<
const
int4
*>
(
token_ptr
);
const
int4
fp8_data
=
nope_src
[
lane_id
];
const
float
*
scales_ptr
=
reinterpret_cast
<
const
float
*>
(
token_ptr
+
512
);
const
float
scale
=
scales_ptr
[
lane_id
>>
3
];
const
uint2
fp8_lo
=
make_uint2
(
fp8_data
.
x
,
fp8_data
.
y
);
const
uint2
fp8_hi
=
make_uint2
(
fp8_data
.
z
,
fp8_data
.
w
);
#ifdef USE_ROCM
const
bf16_8_t
bf16_lo
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_lo
,
scale
);
const
bf16_8_t
bf16_hi
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_hi
,
scale
);
#else
const
bf16_8_t
bf16_lo
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_lo
,
scale
,
__NV_E4M3
);
const
bf16_8_t
bf16_hi
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_hi
,
scale
,
__NV_E4M3
);
#endif
__nv_bfloat16
*
dst_ptr
=
dst
+
out_token_id
*
dst_entry_stride
;
int4
*
nope_dst
=
reinterpret_cast
<
int4
*>
(
dst_ptr
)
+
lane_id
*
2
;
nope_dst
[
0
]
=
*
reinterpret_cast
<
const
int4
*>
(
&
bf16_lo
);
nope_dst
[
1
]
=
*
reinterpret_cast
<
const
int4
*>
(
&
bf16_hi
);
const
int
*
rope_src
=
reinterpret_cast
<
const
int
*>
(
token_ptr
+
528
);
int
*
rope_dst
=
reinterpret_cast
<
int
*>
(
dst_ptr
+
512
);
rope_dst
[
lane_id
]
=
rope_src
[
lane_id
];
}
//
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
//
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
//
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
//
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
//
const int32_t* __restrict__ workspace_starts, // [num_reqs]
//
const int32_t num_reqs, const int32_t block_size,
//
const int32_t total_tokens, const int64_t block_table_stride,
//
const int64_t cache_block_stride, const int64_t cache_entry_stride,
//
const int64_t dst_entry_stride) {
//
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
//
if (flat_warp_id >= total_tokens) return;
//
const int lane_id = threadIdx.x & 31;
//
// Binary search to find which request owns this output token
//
int lo = 0, hi = num_reqs - 1;
//
while (lo < hi) {
//
int mid = (lo + hi + 1) >> 1;
//
if (workspace_starts[mid] <= flat_warp_id)
//
lo = mid;
//
else
//
hi = mid - 1;
//
}
//
const int req_id = lo;
//
// Compute physical token address via block table
//
const int out_token_id = flat_warp_id;
//
const int token_offset = out_token_id - workspace_starts[req_id];
//
const int cache_block_idx = token_offset / block_size;
//
const int offset_in_block = token_offset % block_size;
//
const int physical_block =
//
block_table[req_id * block_table_stride + cache_block_idx];
//
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
//
offset_in_block * cache_entry_stride;
//
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
//
const int4 fp8_data = nope_src[lane_id];
//
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
//
const float scale = scales_ptr[lane_id >> 3];
//
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
//
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
//
#ifdef USE_ROCM
//
const bf16_8_t bf16_lo =
//
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
//
const bf16_8_t bf16_hi =
//
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
//
#else
//
const bf16_8_t bf16_lo =
//
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
//
const bf16_8_t bf16_hi =
//
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
//
#endif
//
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
//
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
//
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
//
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
//
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
//
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
//
rope_dst[lane_id] = rope_src[lane_id];
//
}
template
<
typename
scalar_t
>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
...
...
@@ -1213,69 +1213,69 @@ void cp_gather_cache(
}
}
void
cp_gather_and_upconvert_fp8_kv_cache
(
torch
::
Tensor
const
&
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE, 656]
torch
::
Tensor
const
&
dst
,
// [TOT_TOKENS, 576]
torch
::
Tensor
const
&
block_table
,
// [BATCH, BLOCK_INDICES]
torch
::
Tensor
const
&
seq_lens
,
// [BATCH]
torch
::
Tensor
const
&
workspace_starts
,
// [BATCH]
int64_t
batch_size
)
{
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
src_cache
.
device
());
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int32_t
block_size
=
src_cache
.
size
(
1
);
int32_t
head_dim
=
dst
.
size
(
1
);
TORCH_CHECK
(
block_table
.
dtype
()
==
torch
::
kInt32
,
"block_table must be int32"
);
TORCH_CHECK
(
seq_lens
.
dtype
()
==
torch
::
kInt32
,
"seq_lens must be int32"
);
TORCH_CHECK
(
workspace_starts
.
dtype
()
==
torch
::
kInt32
,
"workspace_starts must be int32"
);
TORCH_CHECK
(
src_cache
.
device
()
==
dst
.
device
(),
"src_cache and dst must be on the same device"
);
TORCH_CHECK
(
src_cache
.
device
()
==
block_table
.
device
(),
"src_cache and block_table must be on the same device"
);
TORCH_CHECK
(
src_cache
.
device
()
==
seq_lens
.
device
(),
"src_cache and seq_lens must be on the same device"
);
TORCH_CHECK
(
src_cache
.
device
()
==
workspace_starts
.
device
(),
"src_cache and workspace_starts must be on the same device"
);
auto
dtype
=
src_cache
.
scalar_type
();
TORCH_CHECK
(
dtype
==
at
::
ScalarType
::
Byte
||
// uint8
dtype
==
at
::
ScalarType
::
Float8_e4m3fn
||
// fp8 e4m3
dtype
==
at
::
ScalarType
::
Float8_e5m2
,
// fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got "
,
src_cache
.
dtype
());
TORCH_CHECK
(
dst
.
dtype
()
==
torch
::
kBFloat16
,
"dst must be bfloat16"
);
TORCH_CHECK
(
head_dim
==
576
,
"head_dim must be 576 for MLA"
);
int64_t
block_table_stride
=
block_table
.
stride
(
0
);
int64_t
cache_block_stride
=
src_cache
.
stride
(
0
);
int64_t
cache_entry_stride
=
src_cache
.
stride
(
1
);
int64_t
dst_entry_stride
=
dst
.
stride
(
0
);
const
uint8_t
*
src_ptr
=
nullptr
;
if
(
dtype
==
at
::
ScalarType
::
Byte
)
{
src_ptr
=
src_cache
.
data_ptr
<
uint8_t
>
();
}
else
{
// float8_e4m3fn or float8_e5m2
src_ptr
=
reinterpret_cast
<
const
uint8_t
*>
(
src_cache
.
data_ptr
());
}
const
int
total_tokens
=
dst
.
size
(
0
);
constexpr
int
warps_per_block
=
8
;
const
int
grid_size
=
(
total_tokens
+
warps_per_block
-
1
)
/
warps_per_block
;
const
int
block_size_threads
=
warps_per_block
*
32
;
// 256 threads
vllm
::
cp_gather_and_upconvert_fp8_kv_cache
<<<
grid_size
,
block_size_threads
,
0
,
stream
>>>
(
src_ptr
,
reinterpret_cast
<
__nv_bfloat16
*>
(
dst
.
data_ptr
()),
block_table
.
data_ptr
<
int32_t
>
(),
workspace_starts
.
data_ptr
<
int32_t
>
(),
static_cast
<
int32_t
>
(
batch_size
),
block_size
,
total_tokens
,
block_table_stride
,
cache_block_stride
,
cache_entry_stride
,
dst_entry_stride
);
}
//
void cp_gather_and_upconvert_fp8_kv_cache(
//
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
//
torch::Tensor const& dst, // [TOT_TOKENS, 576]
//
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
//
torch::Tensor const& seq_lens, // [BATCH]
//
torch::Tensor const& workspace_starts, // [BATCH]
//
int64_t batch_size) {
//
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
//
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
//
int32_t block_size = src_cache.size(1);
//
int32_t head_dim = dst.size(1);
//
TORCH_CHECK(block_table.dtype() == torch::kInt32,
//
"block_table must be int32");
//
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
//
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
//
"workspace_starts must be int32");
//
TORCH_CHECK(src_cache.device() == dst.device(),
//
"src_cache and dst must be on the same device");
//
TORCH_CHECK(src_cache.device() == block_table.device(),
//
"src_cache and block_table must be on the same device");
//
TORCH_CHECK(src_cache.device() == seq_lens.device(),
//
"src_cache and seq_lens must be on the same device");
//
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
//
"src_cache and workspace_starts must be on the same device");
//
auto dtype = src_cache.scalar_type();
//
TORCH_CHECK(
//
dtype == at::ScalarType::Byte || // uint8
//
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
//
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
//
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
//
src_cache.dtype());
//
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
//
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
//
int64_t block_table_stride = block_table.stride(0);
//
int64_t cache_block_stride = src_cache.stride(0);
//
int64_t cache_entry_stride = src_cache.stride(1);
//
int64_t dst_entry_stride = dst.stride(0);
//
const uint8_t* src_ptr = nullptr;
//
if (dtype == at::ScalarType::Byte) {
//
src_ptr = src_cache.data_ptr<uint8_t>();
//
} else {
//
// float8_e4m3fn or float8_e5m2
//
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
//
}
//
const int total_tokens = dst.size(0);
//
constexpr int warps_per_block = 8;
//
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
//
const int block_size_threads = warps_per_block * 32; // 256 threads
//
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
//
stream>>>(
//
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
//
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
//
static_cast<int32_t>(batch_size), block_size, total_tokens,
//
block_table_stride, cache_block_stride, cache_entry_stride,
//
dst_entry_stride);
//
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
...
...
csrc/cuda_vec_utils.cuh
View file @
82155c76
...
...
@@ -8,6 +8,8 @@
#include <cassert>
#ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
...
...
csrc/torch_bindings.cpp
View file @
82155c76
...
...
@@ -799,12 +799,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"
);
cache_ops
.
impl
(
"cp_gather_cache"
,
torch
::
kCUDA
,
&
cp_gather_cache
);
cache_ops
.
def
(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()"
);
cache_ops
.
impl
(
"cp_gather_and_upconvert_fp8_kv_cache"
,
torch
::
kCUDA
,
&
cp_gather_and_upconvert_fp8_kv_cache
);
//
cache_ops.def(
//
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
//
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
//
"batch_size) -> ()");
//
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
//
&cp_gather_and_upconvert_fp8_kv_cache);
cache_ops
.
def
(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
...
...
vllm/_custom_ops.py
View file @
82155c76
...
...
@@ -2713,27 +2713,27 @@ def cp_gather_cache(
)
def
cp_gather_and_upconvert_fp8_kv_cache
(
src_cache
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
workspace_starts
:
torch
.
Tensor
,
batch_size
:
int
,
)
->
None
:
"""Gather and upconvert FP8 KV cache to BF16 workspace.
Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests
"""
torch
.
ops
.
_C_cache_ops
.
cp_gather_and_upconvert_fp8_kv_cache
(
src_cache
,
dst
,
block_table
,
seq_lens
,
workspace_starts
,
batch_size
)
#
def cp_gather_and_upconvert_fp8_kv_cache(
#
src_cache: torch.Tensor,
#
dst: torch.Tensor,
#
block_table: torch.Tensor,
#
seq_lens: torch.Tensor,
#
workspace_starts: torch.Tensor,
#
batch_size: int,
#
) -> None:
#
"""Gather and upconvert FP8 KV cache to BF16 workspace.
#
Args:
#
src_cache: FP8 KV cache [num_blocks, block_size, 656]
#
dst: BF16 output workspace [total_tokens, 576]
#
block_table: Block indices [num_reqs, max_blocks]
#
seq_lens: Sequence lengths [num_reqs]
#
workspace_starts: Workspace start offsets [num_reqs]
#
batch_size: Number of requests
#
"""
#
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
#
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
#
)
def
concat_mla_q
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment