Commit ba0e5fa5 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Hipify self_multihead_attn_bias_additive_mask.

parent c3ec9351
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd( ...@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment