Commit 8bdbb502 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Hipify encdec_multihead_attn

parent ba0e5fa5
......@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
namespace rocblas_gemm_ex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
......@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
);
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemm_ex
} // end namespace encdec
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
m.def("forward", &multihead_attn::encdec::rocblas_gemm_ex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment