Commit 7f7894c0 authored by zhuwenwen
Browse files

skip concat_and_cache_mla_rope_fused

parent 7e63ef82
......@@ -287,7 +287,7 @@ endif()
set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu"
# "csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
......
......@@ -34,11 +34,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& scale);
// Fused MLA cache write. The op registration for this symbol describes it as
// "Rotate Q and K, then write to kv cache for MLA".
//
// This is a declaration only; the definition is not visible here.
// NOTE(review): presumably defined in csrc/cache_kernels_fused.cu - confirm.
//
// Arguments, in the same order as the registered op schema:
//   positions, rope_cos_sin_cache, rope_is_neox
//       Inputs to the rotation step. NOTE(review): exact shapes/dtypes are
//       not established by this declaration - confirm against the kernel.
//   q_pe, k_pe
//       Marked `Tensor!` in the registered schema, i.e. declared as mutated
//       by the op.
//   kv_c
//       Not marked as mutated in the registered schema.
//   kv_cache_slot_mapping, kv_cache
//       kv_cache is marked `Tensor!` in the registered schema (written to).
//   kv_cache_dtype, kv_cache_quant_scale
//       Cache dtype name and scale tensor. NOTE(review): accepted dtype
//       strings are not visible here - confirm.
//
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// void concat_and_cache_mla_rope_fused(
// torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
// torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
// torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
// const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
......
......@@ -792,20 +792,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Rotate Q and K, then write to kv cache for MLA
//
// Registers the op schema and then binds its implementation.
// - Schema arguments are listed in the same order as the parameters of the
//   C++ function `concat_and_cache_mla_rope_fused` bound below.
// - `Tensor!` marks q_pe, k_pe and kv_cache as arguments the op mutates;
//   the remaining tensors are declared read-only.
// - `-> ()` declares that the op returns nothing.
cache_ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
// Only a CUDA-dispatch implementation is registered here; the op name must
// match the name used in the schema string above.
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
&concat_and_cache_mla_rope_fused);
// cache_ops.def(
// "concat_and_cache_mla_rope_fused("
// " Tensor positions,"
// " Tensor! q_pe,"
// " Tensor! k_pe,"
// " Tensor kv_c,"
// " Tensor cos_sin_cache,"
// " bool is_neox,"
// " Tensor slot_mapping,"
// " Tensor! kv_cache,"
// " str kv_cache_dtype,"
// " Tensor kv_cache_scale) -> ()");
// cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
// &concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
......
......@@ -2636,30 +2636,30 @@ def concat_and_cache_mla(
)
def concat_and_cache_mla_rope_fused(
    positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    kv_cache_scale: torch.Tensor,
) -> None:
    """Invoke the ``_C_cache_ops.concat_and_cache_mla_rope_fused`` custom op.

    Thin pass-through: every argument is forwarded positionally, in the
    order it is declared here, and nothing is returned. The op's registered
    schema marks ``q_pe``, ``k_pe`` and ``kv_cache`` as mutated, so callers
    should expect those tensors to be updated by the call.
    """
    fused_op = torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused
    # Keep the tuple order in lock-step with the op schema; the call is
    # purely positional.
    forwarded = (
        positions,
        q_pe,
        k_pe,
        kv_c,
        cos_sin_cache,
        is_neox,
        slot_mapping,
        kv_cache,
        kv_cache_dtype,
        kv_cache_scale,
    )
    fused_op(*forwarded)
# def concat_and_cache_mla_rope_fused(
# positions: torch.Tensor,
# q_pe: torch.Tensor,
# k_pe: torch.Tensor,
# kv_c: torch.Tensor,
# cos_sin_cache: torch.Tensor,
# is_neox: bool,
# slot_mapping: torch.Tensor,
# kv_cache: torch.Tensor,
# kv_cache_dtype: str,
# kv_cache_scale: torch.Tensor,
# ) -> None:
# torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
# positions,
# q_pe,
# k_pe,
# kv_c,
# cos_sin_cache,
# is_neox,
# slot_mapping,
# kv_cache,
# kv_cache_dtype,
# kv_cache_scale,
# )
def swap_blocks(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment