Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b09a0d7b
Commit
b09a0d7b
authored
Jan 21, 2026
by
zhuwenwen
Browse files
skip concat_and_cache_mla_rope_fused
parent
5b1db8b2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
44 deletions
+44
-44
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/cache.h
csrc/cache.h
+5
-5
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+14
-14
vllm/_custom_ops.py
vllm/_custom_ops.py
+24
-24
No files found.
CMakeLists.txt
View file @
b09a0d7b
...
...
@@ -282,7 +282,7 @@ endif()
set
(
VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu"
#
"csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
...
...
csrc/cache.h
View file @
b09a0d7b
...
...
@@ -28,11 +28,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch
::
Tensor
&
scale
);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void
concat_and_cache_mla_rope_fused
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
q_pe
,
torch
::
Tensor
&
k_pe
,
torch
::
Tensor
&
kv_c
,
torch
::
Tensor
&
rope_cos_sin_cache
,
bool
rope_is_neox
,
torch
::
Tensor
&
kv_cache_slot_mapping
,
torch
::
Tensor
&
kv_cache
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
kv_cache_quant_scale
);
//
void concat_and_cache_mla_rope_fused(
//
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
//
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
//
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
//
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// Just for unittest
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
...
...
csrc/torch_bindings.cpp
View file @
b09a0d7b
...
...
@@ -742,20 +742,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops
.
impl
(
"concat_and_cache_mla"
,
torch
::
kCUDA
,
&
concat_and_cache_mla
);
// Rotate Q and K, then write to kv cache for MLA
cache_ops
.
def
(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()"
);
cache_ops
.
impl
(
"concat_and_cache_mla_rope_fused"
,
torch
::
kCUDA
,
&
concat_and_cache_mla_rope_fused
);
//
cache_ops.def(
//
"concat_and_cache_mla_rope_fused("
//
" Tensor positions,"
//
" Tensor! q_pe,"
//
" Tensor! k_pe,"
//
" Tensor kv_c,"
//
" Tensor cos_sin_cache,"
//
" bool is_neox,"
//
" Tensor slot_mapping,"
//
" Tensor! kv_cache,"
//
" str kv_cache_dtype,"
//
" Tensor kv_cache_scale) -> ()");
//
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
//
&concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type.
cache_ops
.
def
(
...
...
vllm/_custom_ops.py
View file @
b09a0d7b
...
...
@@ -2423,30 +2423,30 @@ def concat_and_cache_mla(
)
def
concat_and_cache_mla_rope_fused
(
positions
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
slot_mapping
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_scale
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
concat_and_cache_mla_rope_fused
(
positions
,
q_pe
,
k_pe
,
kv_c
,
cos_sin_cache
,
is_neox
,
slot_mapping
,
kv_cache
,
kv_cache_dtype
,
kv_cache_scale
,
)
#
def concat_and_cache_mla_rope_fused(
#
positions: torch.Tensor,
#
q_pe: torch.Tensor,
#
k_pe: torch.Tensor,
#
kv_c: torch.Tensor,
#
cos_sin_cache: torch.Tensor,
#
is_neox: bool,
#
slot_mapping: torch.Tensor,
#
kv_cache: torch.Tensor,
#
kv_cache_dtype: str,
#
kv_cache_scale: torch.Tensor,
#
) -> None:
#
torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
#
positions,
#
q_pe,
#
k_pe,
#
kv_c,
#
cos_sin_cache,
#
is_neox,
#
slot_mapping,
#
kv_cache,
#
kv_cache_dtype,
#
kv_cache_scale,
#
)
def
swap_blocks
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment