Commit 7f7894c0 authored by zhuwenwen
Browse files

skip concat_and_cache_mla_rope_fused

parent 7e63ef82
...@@ -287,7 +287,7 @@ endif() ...@@ -287,7 +287,7 @@ endif()
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu" "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/cache_kernels.cu" "csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu" # "csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu" "csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu" "csrc/attention/merge_attn_states.cu"
......
...@@ -34,11 +34,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, ...@@ -34,11 +34,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& scale); torch::Tensor& scale);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla // NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused( // void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe, // torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox, // torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache, // torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale); // const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
......
...@@ -792,20 +792,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -792,20 +792,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Rotate Q and K, then write to kv cache for MLA // Rotate Q and K, then write to kv cache for MLA
cache_ops.def( // cache_ops.def(
"concat_and_cache_mla_rope_fused(" // "concat_and_cache_mla_rope_fused("
" Tensor positions," // " Tensor positions,"
" Tensor! q_pe," // " Tensor! q_pe,"
" Tensor! k_pe," // " Tensor! k_pe,"
" Tensor kv_c," // " Tensor kv_c,"
" Tensor cos_sin_cache," // " Tensor cos_sin_cache,"
" bool is_neox," // " bool is_neox,"
" Tensor slot_mapping," // " Tensor slot_mapping,"
" Tensor! kv_cache," // " Tensor! kv_cache,"
" str kv_cache_dtype," // " str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()"); // " Tensor kv_cache_scale) -> ()");
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA, // cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
&concat_and_cache_mla_rope_fused); // &concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type. // Convert the key and value cache to fp8 data type.
cache_ops.def( cache_ops.def(
......
...@@ -2636,30 +2636,30 @@ def concat_and_cache_mla( ...@@ -2636,30 +2636,30 @@ def concat_and_cache_mla(
) )
def concat_and_cache_mla_rope_fused( # def concat_and_cache_mla_rope_fused(
positions: torch.Tensor, # positions: torch.Tensor,
q_pe: torch.Tensor, # q_pe: torch.Tensor,
k_pe: torch.Tensor, # k_pe: torch.Tensor,
kv_c: torch.Tensor, # kv_c: torch.Tensor,
cos_sin_cache: torch.Tensor, # cos_sin_cache: torch.Tensor,
is_neox: bool, # is_neox: bool,
slot_mapping: torch.Tensor, # slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, # kv_cache: torch.Tensor,
kv_cache_dtype: str, # kv_cache_dtype: str,
kv_cache_scale: torch.Tensor, # kv_cache_scale: torch.Tensor,
) -> None: # ) -> None:
torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused( # torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
positions, # positions,
q_pe, # q_pe,
k_pe, # k_pe,
kv_c, # kv_c,
cos_sin_cache, # cos_sin_cache,
is_neox, # is_neox,
slot_mapping, # slot_mapping,
kv_cache, # kv_cache,
kv_cache_dtype, # kv_cache_dtype,
kv_cache_scale, # kv_cache_scale,
) # )
def swap_blocks( def swap_blocks(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment