Unverified Commit 06ffc7e1 authored by TY-AMD's avatar TY-AMD Committed by GitHub
Browse files

[Misc][ROCm] Exclude `cutlass_mla_decode` for ROCm build (#17289)


Signed-off-by: default avatarTianyuan Wu <Tianyuan.Wu@amd.com>
parent d3cf61b8
...@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()"); ") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm // Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( ops.def(
...@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
// CUTLASS MLA decode
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Mamba selective scan kernel // Mamba selective scan kernel
ops.def( ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta," "selective_scan_fwd(Tensor! u, Tensor! delta,"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment