"vscode:/vscode.git/clone" did not exist on "4eade17a6e1baf6bc5c71daac7fc3ac595c378a2"
Commit 2d465ec7 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip cutlass_mla_decode

parent 081057de
......@@ -130,11 +130,11 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);
// void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
// torch::Tensor const& q_pe,
// torch::Tensor const& kv_c_and_k_pe_cache,
// torch::Tensor const& seq_lens,
// torch::Tensor const& page_table, double scale);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
......
......@@ -131,11 +131,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// ops.def(
// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
// " Tensor page_table, float scale) -> ()");
// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
......
......@@ -1533,10 +1533,10 @@ def flash_mla_with_kvcache(
return out, softmax_lse
def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
scale: float) -> torch.Tensor:
torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table, scale)
return out
# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
# seq_lens: torch.Tensor, page_table: torch.Tensor,
# scale: float) -> torch.Tensor:
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment