Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
186aced5
Unverified
Commit
186aced5
authored
Aug 28, 2025
by
yzds
Committed by
GitHub
Aug 28, 2025
Browse files
[Kernel] cuda kernels for upcoming decode context parallel feature (#23791)
Co-authored-by:
hongchao
<
hongchao@msh.team
>
parent
daa1273b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
374 additions
and
1 deletion
+374
-1
csrc/cache.h
csrc/cache.h
+16
-1
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+247
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+15
-0
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+72
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+24
-0
No files found.
csrc/cache.h
View file @
186aced5
...
...
@@ -36,6 +36,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
scale
);
// Fused variant of concat_and_cache_mla: the rows of kv_c and k_pe that get
// written into kv_cache are picked through cp_local_token_select_indices,
// one selected row per entry of slot_mapping.
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                                   torch::Tensor& cp_local_token_select_indices,
                                   torch::Tensor& kv_cache,
                                   torch::Tensor& slot_mapping,
                                   const std::string& kv_cache_dtype,
                                   torch::Tensor& scale);
// Just for unittest
// NOTE(review): presumably converts src_cache into dst_cache using `scale`
// and the given kv_cache_dtype; the definition is not visible here - confirm.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);
...
...
@@ -47,4 +54,12 @@ void gather_and_maybe_dequant_cache(
torch
::
Tensor
const
&
cu_seq_lens
,
// [BATCH+1]
int64_t
batch_size
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
const
&
scale
,
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
);
\ No newline at end of file
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
);
// Gathers the cache entries of each sequence from the paged src_cache into
// dst, using block_table to locate blocks and cu_seq_lens to place rows.
// seq_starts, when given, shifts the first gathered slot of each sequence.
// TODO(hc): cp_gather_cache needs to support scaled kv-cache in the future.
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
csrc/cache_kernels.cu
View file @
186aced5
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include "cuda_utils.h"
#include "cuda_compat.h"
...
...
@@ -395,6 +396,51 @@ __global__ void concat_and_cache_mla_kernel(
copy
(
k_pe
,
kv_cache
,
k_pe_stride
,
block_stride
,
pe_dim
,
kv_lora_rank
);
}
// Stores one selected token into its paged-cache entry.
//
// Thread block `blockIdx.x` handles slot_mapping[blockIdx.x]; the source row
// in kv_c / k_pe is cp_local_token_select_indices[blockIdx.x]. The entry is
// laid out as kv_lora_rank elements of kv_c followed by pe_dim elements of
// k_pe. When kv_dt is not kAuto every element goes through
// fp8::scaled_convert with *scale.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                     // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,  // elements between consecutive cache blocks
    const int entry_stride,  // elements between consecutive entries in a block
    const int kv_c_stride,   // elements between consecutive kv_c rows
    const int k_pe_stride,   // elements between consecutive k_pe rows
    const int kv_lora_rank,  // number of kv_c elements per token
    const int pe_dim,        // number of k_pe elements per token
    const int block_size,    // slots per cache block
    const float* scale       // scale forwarded to fp8::scaled_convert
) {
  const int64_t src_row = cp_local_token_select_indices[blockIdx.x];
  const int64_t slot = slot_mapping[blockIdx.x];
  // NOTE: a negative slot marks a padded token; there is nothing to store.
  if (slot < 0) {
    return;
  }

  // Element offset of this slot's entry inside kv_cache.
  const int64_t entry_base =
      (slot / block_size) * block_stride + (slot % block_size) * entry_stride;

  // Copies `width` elements of row `src_row` of `src` into the entry,
  // starting at element `column`; threads of the block stride over elements.
  auto store_row = [&](const scalar_t* __restrict__ src, const int src_stride,
                       const int width, const int column) {
    const int64_t src_base = src_row * src_stride;
    for (int i = threadIdx.x; i < width; i += blockDim.x) {
      const int64_t from = src_base + i;
      const int64_t to = entry_base + i + column;
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        kv_cache[to] = src[from];
      } else {
        kv_cache[to] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[from], *scale);
      }
    }
  };

  store_row(kv_c, kv_c_stride, kv_lora_rank, 0);
  store_row(k_pe, k_pe_stride, pe_dim, kv_lora_rank);
}
}
// namespace vllm
// KV_T is the data type of key and value tensors.
...
...
@@ -508,6 +554,20 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// Launches vllm::cp_fused_concat_and_cache_mla_kernel. The expansion site
// must have grid, block, stream, the tensors (kv_c, k_pe,
// cp_local_token_select_indices, kv_cache, slot_mapping, scale) and the
// stride/size locals named below in scope.
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
cp_local_token_select_indices.data_ptr<int64_t>(), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void
concat_and_cache_mla
(
torch
::
Tensor
&
kv_c
,
// [num_tokens, kv_lora_rank]
torch
::
Tensor
&
k_pe
,
// [num_tokens, pe_dim]
...
...
@@ -546,6 +606,50 @@ void concat_and_cache_mla(
CALL_CONCAT_AND_CACHE_MLA
);
}
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
//   k_c_normed.index_select(0, cp_local_token_select_indices),
//   k_pe.squeeze(1).index_select(0, cp_local_token_select_indices),
//   concat_and_cache_mla.
//
// One thread block is launched per entry of slot_mapping. Fails a
// TORCH_CHECK when kv_cache's last dimension is not kv_lora_rank + pe_dim, or
// when cp_local_token_select_indices has fewer entries than slot_mapping.
void cp_fused_concat_and_cache_mla(
    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // NOTE(woosuk): In vLLM V1 the input tensors may carry padding for CUDA
  // graphs while slot_mapping does not; in V0 both include padding.
  // slot_mapping.size(0) is therefore the number of tokens to write in both
  // cases, and it is what sizes the launch grid.
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  // Every launched block reads cp_local_token_select_indices[blockIdx.x], so
  // the index tensor must cover all of slot_mapping.
  TORCH_CHECK(cp_local_token_select_indices.numel() >= num_tokens,
              "cp_local_token_select_indices must have at least as many "
              "entries as slot_mapping");

  // A zero-sized grid is an invalid launch configuration; nothing to do.
  if (num_tokens == 0) {
    return;
  }

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(kv_lora_rank, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The macro picks up grid/block/stream and the locals above by name.
  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -779,3 +883,146 @@ void gather_and_maybe_dequant_cache(
DISPATCH_BY_KV_CACHE_DTYPE
(
dst
.
dtype
(),
kv_cache_dtype
,
CALL_GATHER_CACHE
);
}
namespace vllm {

// Gathers cache entries of one sequence into a contiguous destination.
//
// Grid layout: blockIdx.x is the batch (sequence) index, blockIdx.y is the
// split index; the gridDim.y splits partition the sequence's slots into
// contiguous ranges. Threads of a block stride over the elements of an entry.
//
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
template <typename scalar_t>
__global__ void cp_gather_cache(
    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
                                              // ENTRY_SIZE]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRY_SIZE]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const int32_t* __restrict__ seq_starts  // Optional: starting offsets per
                                            // batch
) {
  const int64_t bid = blockIdx.x;  // Batch ID
  const int32_t num_splits = gridDim.y;
  const int32_t split = blockIdx.y;
  const int32_t seq_start = cu_seq_lens[bid];
  const int32_t seq_end = cu_seq_lens[bid + 1];
  const int32_t seq_len = seq_end - seq_start;
  const int32_t tot_slots = seq_len;
  const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

  const int32_t split_start = split * split_slots;
  const int32_t split_end = min((split + 1) * split_slots, tot_slots);

  // Splits that start past the end of the sequence have no work.
  const bool is_active_split = (split_start < tot_slots);
  if (!is_active_split) return;

  // Adjust the pointer for the block_table for this batch.
  // Kept in 64 bits: bid and block_table_stride are both int64_t.
  const int64_t batch_offset = bid * block_table_stride;
  // If seq_starts is provided, compute an offset based on it
  int32_t offset = split_start;
  if (seq_starts != nullptr) {
    offset += seq_starts[bid];
  }
  // Split the slot offset into (index into block table, slot within block).
  int32_t offset_div = offset / block_size;
  offset = offset % block_size;
  const int32_t* batch_block_table = block_table + batch_offset;

  // Adjust dst pointer based on the cumulative sequence lengths.
  dst += seq_start * dst_entry_stride;

  auto copy_entry = [&](const scalar_t* __restrict__ _src,
                        scalar_t* __restrict__ _dst) {
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
      _dst[i] = _src[i];
  };

  for (int pid = split_start; pid < split_end; ++pid) {
    auto block_id = batch_block_table[offset_div];
    auto block_start_ptr = src_cache + block_id * cache_block_stride;
    auto block_dst_ptr = dst + pid * dst_entry_stride;
    copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
    offset += 1;
    // bump to next block
    if (offset == block_size) {
      offset_div += 1;
      offset = 0;
    }
  }
}

}  // namespace vllm
// Macro to dispatch the kernel based on the data type.
// CPY_DTYPE is the element type used for the raw copy. The expansion site
// must have grid, block, stream, src_cache, dst, block_table, cu_seq_lens,
// seq_starts_ptr and the size/stride locals named below in scope.
#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \
vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
//  - cu_seq_lens contains the cumulative sequence lengths for each batch
//  - block_table contains the cache block indices for each sequence
//  - Optionally, seq_starts (if provided) offsets the starting slot index by
//    seq_starts[bid]
//
// The copy is done on raw elements of the same width as the cache dtype
// (8, 16 or 32 bits); other widths fail a TORCH_CHECK, as do dtype/device
// mismatches and index tensors too short for batch_size.
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int32_t block_size = src_cache.size(1);
  int32_t entry_size = src_cache.flatten(2, -1).size(2);

  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
              "cu_seq_lens must be int32");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                "seq_starts must be int32");
  }

  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }

  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
              "src_cache and dst must have the same dtype");

  TORCH_CHECK(batch_size >= 0, "batch_size must be non-negative");
  // A zero-sized grid is an invalid launch configuration; nothing to gather.
  if (batch_size == 0) {
    return;
  }

  // The kernel indexes these tensors by batch id without bounds checks.
  TORCH_CHECK(cu_seq_lens.numel() >= batch_size + 1,
              "cu_seq_lens must have at least batch_size + 1 entries");
  TORCH_CHECK(block_table.size(0) >= batch_size,
              "block_table must have at least batch_size rows");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().numel() >= batch_size,
                "seq_starts must have at least batch_size entries");
  }

  int64_t block_table_stride = block_table.stride(0);
  int64_t cache_block_stride = src_cache.stride(0);
  int64_t cache_entry_stride = src_cache.stride(1);
  int64_t dst_entry_stride = dst.stride(0);

  // Decide on the number of splits based on the batch size.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  const int dtype_bits = src_cache.element_size() * 8;
  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  // Pick an unsigned integer type of matching width for the raw copy.
  if (dtype_bits == 32) {
    CALL_CP_GATHER_CACHE(uint32_t);
  } else if (dtype_bits == 16) {
    CALL_CP_GATHER_CACHE(uint16_t);
  } else if (dtype_bits == 8) {
    CALL_CP_GATHER_CACHE(uint8_t);
  } else {
    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
}
csrc/torch_bindings.cpp
View file @
186aced5
...
...
@@ -686,6 +686,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"concat_and_cache_mla"
,
torch
::
kCUDA
,
&
concat_and_cache_mla
);
cache_ops
.
def
(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"cp_fused_concat_and_cache_mla"
,
torch
::
kCUDA
,
&
cp_fused_concat_and_cache_mla
);
// Convert the key and value cache to fp8 data type.
cache_ops
.
def
(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
...
...
@@ -702,6 +712,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale, Tensor? seq_starts) -> ()"
);
cache_ops
.
impl
(
"gather_and_maybe_dequant_cache"
,
torch
::
kCUDA
,
&
gather_and_maybe_dequant_cache
);
cache_ops
.
def
(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"
);
cache_ops
.
impl
(
"cp_gather_cache"
,
torch
::
kCUDA
,
&
cp_gather_cache
);
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cuda_utils
),
cuda_utils
)
{
...
...
tests/kernels/attention/test_cache.py
View file @
186aced5
...
...
@@ -790,6 +790,78 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
torch
.
testing
.
assert_close
(
dst
,
expected
)
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
                         ["auto"])  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                             num_blocks, max_seq_len, batch_size, dtype,
                             kv_cache_dtype, device):
    """Check cp_gather_cache against a reference gather built with indexing.

    Random sequence lengths (possibly 0) and a random block table per
    sequence are generated; the expected output concatenates, per sequence,
    its full blocks followed by the used part of its last block.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)
    # Pull the lengths to the host once; they are used as sizes and ranges.
    seq_lens = seq_len_tensor.tolist()
    total_tokens = sum(seq_lens)

    cu_seq_lens = torch.empty((batch_size + 1),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)
    for b in range(batch_size):
        perm = torch.randperm(num_blocks, device=device)
        block_table[b, :] = perm

    dst = torch.zeros((total_tokens, entry_size),
                      dtype=src_cache.dtype,
                      device=device)

    expected_batches = []
    for b, seq_len in enumerate(seq_lens):
        if seq_len == 0:
            continue
        tot = (seq_len + block_size - 1) // block_size
        blocks = block_table[b, :tot].tolist()

        # All blocks but the last are full; the last one is partially used.
        gathered_rows = [src_cache[blk] for blk in blocks[:-1]]
        remaining = seq_len - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))

    if expected_batches:
        expected = torch.cat(expected_batches, dim=0)
    else:
        # Every sampled length was 0: torch.cat([]) would raise.
        expected = torch.zeros_like(dst)

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
KV_LORA_RANKS
)
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
QK_ROPE_HEAD_DIMS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS_MLA
)
...
...
vllm/_custom_ops.py
View file @
186aced5
...
...
@@ -1625,6 +1625,20 @@ def concat_and_cache_mla(
scale
)
def cp_fused_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    cp_local_token_select_indices: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Call the ``_C_cache_ops.cp_fused_concat_and_cache_mla`` custom op.

    Every argument is forwarded positionally and unchanged; the op's return
    value is discarded.
    """
    fused_op = torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla
    fused_op(kv_c, k_pe, cp_local_token_select_indices, kv_cache,
             slot_mapping, kv_cache_dtype, scale)
def
copy_blocks
(
key_caches
:
list
[
torch
.
Tensor
],
value_caches
:
list
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
...
...
@@ -1662,6 +1676,16 @@ def gather_and_maybe_dequant_cache(
scale
,
seq_starts
)
def cp_gather_cache(src_cache: torch.Tensor,
                    dst: torch.Tensor,
                    block_table: torch.Tensor,
                    cu_seq_lens: torch.Tensor,
                    batch_size: int,
                    seq_starts: Optional[torch.Tensor] = None) -> None:
    """Call the ``_C_cache_ops.cp_gather_cache`` custom op.

    Every argument is forwarded positionally and unchanged (``seq_starts``
    defaults to ``None``); the op's return value is discarded.
    """
    gather_op = torch.ops._C_cache_ops.cp_gather_cache
    gather_op(src_cache, dst, block_table, cu_seq_lens, batch_size,
              seq_starts)
def get_device_attribute(attribute: int, device: int) -> int:
    """Return what the ``_C_cuda_utils.get_device_attribute`` op yields.

    ``attribute`` and ``device`` are forwarded positionally and unchanged.
    """
    query_op = torch.ops._C_cuda_utils.get_device_attribute
    return query_op(attribute, device)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment