Commit 2a77b772 authored by huchen1's avatar huchen1
Browse files

Add FastMoE support for ROCm 4.0.1

parent d2392de2
......@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl(
}
// Per-thread expert-count capacity used by expert_count_kernel
// (NOTE(review): inferred from the macro name — confirm against the kernel body).
#define PERTHREAD_EXPERTS 256
#ifdef MOE_HIP_DIFF
// AMD GPUs under ROCm/HIP execute in 64-lane wavefronts.
#define WARP_SIZE 64
#else
// NVIDIA GPUs execute in 32-lane warps.
#define WARP_SIZE 32
#endif
__global__
void expert_count_kernel(const long* gate_idx, int* expert_count,
......@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count,
int x = res_tmp[i - expert_min];
#pragma unroll
for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef MOE_HIP_DIFF
x = x + __shfl_down(x, j);
#else
x = x + __shfl_down_sync(-1u, x, j);
#endif
}
if (threadIdx.x % WARP_SIZE == 0) {
atomicAdd(expert_count + i, x);
......
......@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
const __half *beta,
__half *Carray[], int ldc,
int batchCount) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm_batched(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* const*)Aarray, lda, (const rocblas_half* const*)Barray, ldb, (const rocblas_half*)beta, (rocblas_half* const*)Carray, ldc, batchCount);
#else
return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
#endif
}
......@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const __half *B, int ldb,
const __half *beta,
__half *C, int ldc) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* )A, lda, (const rocblas_half* )B, ldb, (const rocblas_half*)beta, (rocblas_half* )C, ldc);
#else
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#endif
}
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
......@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const c10::Half *B, int ldb,
const c10::Half *beta,
c10::Half *C, int ldc) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm(handle, transa, transb, m, n, k,
(const rocblas_half*)alpha,
(const rocblas_half*)A, lda,
(const rocblas_half*)B, ldb,
(const rocblas_half*)beta,
(rocblas_half*)C, ldc);
#else
return cublasHgemm(handle, transa, transb, m, n, k,
(const __half*)alpha,
(const __half*)A, lda,
(const __half*)B, ldb,
(const __half*)beta,
(__half*)C, ldc);
#endif
}
#endif // CUBLAS_WRAPPER_H
......@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) {
}
#endif
#ifdef MOE_HIP_DIFF
// Map a rocBLAS status code (cublasStatus_t is aliased to rocblas_status
// when building for HIP) to its human-readable enum name, for use in
// BLAS error-reporting macros. Returns "<unknown>" for unrecognized codes.
static const char *_cudaGetErrorEnum(cublasStatus_t error) {
  switch (error) {
    case rocblas_status_success:
      return "rocblas_status_success";
    case rocblas_status_invalid_handle:
      return "rocblas_status_invalid_handle";
    case rocblas_status_not_implemented:
      return "rocblas_status_not_implemented";
    case rocblas_status_invalid_pointer:
      // Fixed: the returned name previously carried a stray trailing ':'.
      return "rocblas_status_invalid_pointer";
    case rocblas_status_invalid_size:
      return "rocblas_status_invalid_size";
    case rocblas_status_memory_error:
      return "rocblas_status_memory_error";
    case rocblas_status_internal_error:
      return "rocblas_status_internal_error";
    case rocblas_status_perf_degraded:
      return "rocblas_status_perf_degraded";
    case rocblas_status_size_query_mismatch:
      return "rocblas_status_size_query_mismatch";
    case rocblas_status_size_increased:
      return "rocblas_status_size_increased";
    case rocblas_status_size_unchanged:
      return "rocblas_status_size_unchanged";
    case rocblas_status_invalid_value:
      return "rocblas_status_invalid_value";
    case rocblas_status_continue:
      return "rocblas_status_continue";
  }
  return "<unknown>";
}
#else
// cuBLAS API errors
static const char *_cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
......@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) {
return "<unknown>";
}
#endif
#ifdef _CUFFT_H_
// cuFFT API errors
......
......@@ -17,8 +17,16 @@ authors = [
# Build-time feature switches, driven by environment variables.
# Read USE_ROCM once instead of twice (it controlled both the collective
# library choice and the macro list below).
use_rocm = os.environ.get('USE_ROCM', '0') == '1'

# FMOE_USE_NCCL enables the distributed (multi-GPU) code paths.
if os.environ.get('USE_NCCL', '1') == '1':
    cxx_flags.append('-DFMOE_USE_NCCL')
    # ROCm ships RCCL as its NCCL-compatible collective library.
    ext_libs.append('rccl' if use_rocm else 'nccl')

# MOE_HIP_DIFF guards the HIP-specific code in the CUDA sources
# (64-lane WARP_SIZE, legacy __shfl_down, rocBLAS GEMM wrappers).
define_macros = [('MOE_HIP_DIFF', None)] if use_rocm else []
if __name__ == '__main__':
setuptools.setup(
......@@ -41,6 +49,7 @@ if __name__ == '__main__':
'cuda/parallel_linear.cu',
'cuda/fmoe_cuda.cpp',
],
define_macros=define_macros,
extra_compile_args={
'cxx': cxx_flags,
'nvcc': cxx_flags
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment