Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ec64732
Commit
4ec64732
authored
Jan 06, 2026
by
zhuwenwen
Browse files
add indexer_k_cache_kernel
parent
25ec6a34
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
148 additions
and
0 deletions
+148
-0
csrc/cache.h
csrc/cache.h
+7
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+119
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+8
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+7
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+7
-0
No files found.
csrc/cache.h
View file @
4ec64732
...
@@ -83,6 +83,13 @@ void indexer_k_quant_and_cache(
...
@@ -83,6 +83,13 @@ void indexer_k_quant_and_cache(
int64_t
quant_block_size
,
// quantization block size
int64_t
quant_block_size
,
// quantization block size
const
std
::
string
&
scale_fmt
);
const
std
::
string
&
scale_fmt
);
void
indexer_k_cache
(
torch
::
Tensor
&
k
,
// [num_tokens, head_dim]
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
scale_fmt
);
// Extract function to gather quantized K cache
// Extract function to gather quantized K cache
void
cp_gather_indexer_k_quant_cache
(
void
cp_gather_indexer_k_quant_cache
(
const
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
const
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
...
...
csrc/cache_kernels.cu
View file @
4ec64732
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include <cfloat>
#include <cfloat>
#include <map>
#include <map>
#include <vector>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -808,6 +809,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
...
@@ -808,6 +809,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
}
}
}
}
template
<
typename
scalar_t
,
typename
cache_t
>
__global__
void
indexer_k_cache_kernel
(
const
scalar_t
*
__restrict__
k
,
// [num_tokens, head_dim]
cache_t
*
__restrict__
kv_cache
,
// [num_blocks, block_size, cache_stride]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
head_dim
,
// dimension of each head
const
int
cache_block_size
,
// cache block size
const
int
cache_stride
// stride for each token in kv_cache
)
{
constexpr
int
VEC_SIZE
=
4
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
head_dim_idx
=
(
blockIdx
.
y
*
blockDim
.
y
*
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
)
*
VEC_SIZE
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
const
int64_t
block_idx
=
slot_idx
/
cache_block_size
;
const
int64_t
block_offset
=
slot_idx
%
cache_block_size
;
// NOTE: slot_idx can be -1 if the token is padded
if
(
slot_idx
<
0
||
(
head_dim_idx
>=
head_dim
))
{
return
;
}
float2
k_val
=
(
reinterpret_cast
<
const
float2
*>
(
k
))[(
token_idx
*
head_dim
+
head_dim_idx
)
/
VEC_SIZE
];
scalar_t
*
k_val_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
k_val
);
const
int64_t
dst_offset
=
block_idx
*
cache_block_size
*
cache_stride
+
block_offset
*
head_dim
+
head_dim_idx
;
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
i
++
)
{
float
val
=
static_cast
<
float
>
(
k_val_ptr
[
i
]);
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
Half
>::
value
||
std
::
is_same
<
cache_t
,
__half
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
__float2half
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
BFloat16
>::
value
||
std
::
is_same
<
cache_t
,
__nv_bfloat16
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
__float2bfloat16
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
float
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
val
;
}
else
{
kv_cache
[
dst_offset
+
i
]
=
static_cast
<
cache_t
>
(
val
);
}
}
}
template
<
int
BLOCK_Y_SIZE
>
template
<
int
BLOCK_Y_SIZE
>
__global__
void
cp_gather_indexer_k_quant_cache_kernel
(
__global__
void
cp_gather_indexer_k_quant_cache_kernel
(
const
char
*
__restrict__
kv_cache
,
// [num_blocks, block_size,
const
char
*
__restrict__
kv_cache
,
// [num_blocks, block_size,
...
@@ -1791,6 +1838,78 @@ void indexer_k_quant_and_cache(
...
@@ -1791,6 +1838,78 @@ void indexer_k_quant_and_cache(
CALL_INDEXER_K_QUANT_AND_CACHE
);
CALL_INDEXER_K_QUANT_AND_CACHE
);
}
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_CACHE(KV_T, CACHE_T) \
vllm::indexer_k_cache_kernel<KV_T, CACHE_T> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, \
cache_block_size, cache_stride);
void
indexer_k_cache
(
torch
::
Tensor
&
k
,
// [num_tokens, head_dim]
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
scale_fmt
)
{
int
num_tokens
=
k
.
size
(
0
);
int
head_dim
=
k
.
size
(
1
);
int
cache_block_size
=
kv_cache
.
size
(
1
);
int
cache_stride
=
kv_cache
.
size
(
2
);
bool
use_ue8m0
=
scale_fmt
==
"ue8m0"
;
TORCH_CHECK
(
k
.
device
()
==
kv_cache
.
device
(),
"k and kv_cache must be on the same device"
);
TORCH_CHECK
(
k
.
device
()
==
slot_mapping
.
device
(),
"k and slot_mapping must be on the same device"
);
constexpr
int
vec_size
=
4
;
dim3
grid
(
num_tokens
,
(
head_dim
+
vec_size
-
1
)
/
vec_size
);
dim3
block
(
32
,
vec_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
k
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
k
.
scalar_type
(),
"indexer_k_cache"
,
([
&
]
{
using
k_t
=
scalar_t
;
auto
kv_cache_type
=
kv_cache
.
scalar_type
();
if
(
kv_cache_type
==
at
::
ScalarType
::
Float
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
float
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
if
(
kv_cache_type
==
at
::
ScalarType
::
Half
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
at
::
Half
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
at
::
Half
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
if
(
kv_cache_type
==
at
::
ScalarType
::
BFloat16
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
at
::
BFloat16
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
at
::
BFloat16
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported kv_cache dtype: "
,
kv_cache
.
dtype
());
}
}));
}
// Macro to dispatch the kernel based on the data amount.
// Macro to dispatch the kernel based on the data amount.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
...
...
csrc/torch_bindings.cpp
View file @
4ec64732
...
@@ -822,6 +822,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
...
@@ -822,6 +822,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops
.
impl
(
"indexer_k_quant_and_cache"
,
torch
::
kCUDA
,
cache_ops
.
impl
(
"indexer_k_quant_and_cache"
,
torch
::
kCUDA
,
&
indexer_k_quant_and_cache
);
&
indexer_k_quant_and_cache
);
cache_ops
.
def
(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"indexer_k_cache"
,
torch
::
kCUDA
,
&
indexer_k_cache
);
cache_ops
.
def
(
cache_ops
.
def
(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"
);
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"
);
...
...
vllm/_custom_ops.py
View file @
4ec64732
...
@@ -2616,6 +2616,13 @@ def indexer_k_quant_and_cache(
...
@@ -2616,6 +2616,13 @@ def indexer_k_quant_and_cache(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
kv_cache_dtype
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
kv_cache_dtype
)
)
def
indexer_k_cache
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
indexer_k_cache
(
k
,
kv_cache
,
slot_mapping
,
kv_cache_dtype
)
def
cp_gather_indexer_k_quant_cache
(
def
cp_gather_indexer_k_quant_cache
(
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
4ec64732
...
@@ -682,6 +682,13 @@ def sparse_attn_indexer(
...
@@ -682,6 +682,13 @@ def sparse_attn_indexer(
quant_block_size
,
quant_block_size
,
scale_fmt
,
scale_fmt
,
)
)
else
:
ops
.
indexer_k_cache
(
k
,
kv_cache
,
slot_mapping
,
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
if
has_prefill
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment