Commit 8bdbb502 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Hipify encdec_multihead_attn

parent ba0e5fa5
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace encdec { namespace encdec {
namespace cublas_gemmex { namespace rocblas_gemm_ex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd( ...@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemm_ex
} // end namespace encdec } // end namespace encdec
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward."); m.def("forward", &multihead_attn::encdec::rocblas_gemm_ex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward."); m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment