"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "64a83fb5882ff2a3d0e05bee5bed78281895c13b"
Commit 83181423 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Hipify self_multihead_attn

Enable HIP floa to hald conversion
parent 61416180
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
......
#include <vector> #include <vector>
#include <math.h> #include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self { namespace self {
namespace cublas_gemmex { namespace rocblas_gemm_ex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd( ...@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
); );
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemm_ex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward."); m.def("forward", &multihead_attn::self::rocblas_gemm_ex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward."); m.def("backward", &multihead_attn::self::rocblas_gemm_ex::bwd, "Self Multihead Attention Backward.");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment