Commit 2a77b772 authored by huchen1's avatar huchen1
Browse files

add fastmoe support rocm4.0.1

parent d2392de2
...@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl( ...@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl(
} }
#define PERTHREAD_EXPERTS 256 #define PERTHREAD_EXPERTS 256
#ifdef MOE_HIP_DIFF
#define WARP_SIZE 64
#else
#define WARP_SIZE 32 #define WARP_SIZE 32
#endif
__global__ __global__
void expert_count_kernel(const long* gate_idx, int* expert_count, void expert_count_kernel(const long* gate_idx, int* expert_count,
...@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count, ...@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count,
int x = res_tmp[i - expert_min]; int x = res_tmp[i - expert_min];
#pragma unroll #pragma unroll
for (int j = 1; j < WARP_SIZE; j <<= 1) { for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef MOE_HIP_DIFF
x = x + __shfl_down(x, j);
#else
x = x + __shfl_down_sync(-1u, x, j); x = x + __shfl_down_sync(-1u, x, j);
#endif
} }
if (threadIdx.x % WARP_SIZE == 0) { if (threadIdx.x % WARP_SIZE == 0) {
atomicAdd(expert_count + i, x); atomicAdd(expert_count + i, x);
......
...@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle, ...@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
const __half *beta, const __half *beta,
__half *Carray[], int ldc, __half *Carray[], int ldc,
int batchCount) { int batchCount) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm_batched(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* const*)Aarray, lda, (const rocblas_half* const*)Barray, ldb, (const rocblas_half*)beta, (rocblas_half* const*)Carray, ldc, batchCount);
#else
return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
#endif
} }
...@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle, ...@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const __half *B, int ldb, const __half *B, int ldb,
const __half *beta, const __half *beta,
__half *C, int ldc) { __half *C, int ldc) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* )A, lda, (const rocblas_half* )B, ldb, (const rocblas_half*)beta, (rocblas_half* )C, ldc);
#else
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#endif
} }
inline cublasStatus_t cublasXgemm(cublasHandle_t handle, inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
...@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle, ...@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const c10::Half *B, int ldb, const c10::Half *B, int ldb,
const c10::Half *beta, const c10::Half *beta,
c10::Half *C, int ldc) { c10::Half *C, int ldc) {
#ifdef MOE_HIP_DIFF
return rocblas_hgemm(handle, transa, transb, m, n, k,
(const rocblas_half*)alpha,
(const rocblas_half*)A, lda,
(const rocblas_half*)B, ldb,
(const rocblas_half*)beta,
(rocblas_half*)C, ldc);
#else
return cublasHgemm(handle, transa, transb, m, n, k, return cublasHgemm(handle, transa, transb, m, n, k,
(const __half*)alpha, (const __half*)alpha,
(const __half*)A, lda, (const __half*)A, lda,
(const __half*)B, ldb, (const __half*)B, ldb,
(const __half*)beta, (const __half*)beta,
(__half*)C, ldc); (__half*)C, ldc);
#endif
} }
#endif // CUBLAS_WRAPPER_H #endif // CUBLAS_WRAPPER_H
...@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) { ...@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) {
} }
#endif #endif
#ifdef MOE_HIP_DIFF
static const char *_cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case rocblas_status_success:
return "rocblas_status_success";
case rocblas_status_invalid_handle:
return "rocblas_status_invalid_handle";
case rocblas_status_not_implemented:
return "rocblas_status_not_implemented";
case rocblas_status_invalid_pointer:
return "rocblas_status_invalid_pointer:";
case rocblas_status_invalid_size:
return "rocblas_status_invalid_size";
case rocblas_status_memory_error:
return "rocblas_status_memory_error";
case rocblas_status_internal_error:
return "rocblas_status_internal_error";
case rocblas_status_perf_degraded:
return "rocblas_status_perf_degraded";
case rocblas_status_size_query_mismatch:
return "rocblas_status_size_query_mismatch";
case rocblas_status_size_increased:
return "rocblas_status_size_increased";
case rocblas_status_size_unchanged:
return "rocblas_status_size_unchanged";
case rocblas_status_invalid_value:
return "rocblas_status_invalid_value";
case rocblas_status_continue:
return "rocblas_status_continue";
}
return "<unknown>";
}
#else
// cuBLAS API errors // cuBLAS API errors
static const char *_cudaGetErrorEnum(cublasStatus_t error) { static const char *_cudaGetErrorEnum(cublasStatus_t error) {
switch (error) { switch (error) {
...@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) { ...@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) {
return "<unknown>"; return "<unknown>";
} }
#endif
#ifdef _CUFFT_H_ #ifdef _CUFFT_H_
// cuFFT API errors // cuFFT API errors
......
...@@ -17,7 +17,15 @@ authors = [ ...@@ -17,7 +17,15 @@ authors = [
if os.environ.get('USE_NCCL', '1') == '1': if os.environ.get('USE_NCCL', '1') == '1':
cxx_flags.append('-DFMOE_USE_NCCL') cxx_flags.append('-DFMOE_USE_NCCL')
ext_libs.append('nccl') if os.environ.get('USE_ROCM', '0') == '1':
ext_libs.append('rccl')
else:
ext_libs.append('nccl')
if os.environ.get('USE_ROCM', '0') == '1':
define_macros=[('MOE_HIP_DIFF', None)]
else:
define_macros=[]
if __name__ == '__main__': if __name__ == '__main__':
...@@ -41,6 +49,7 @@ if __name__ == '__main__': ...@@ -41,6 +49,7 @@ if __name__ == '__main__':
'cuda/parallel_linear.cu', 'cuda/parallel_linear.cu',
'cuda/fmoe_cuda.cpp', 'cuda/fmoe_cuda.cpp',
], ],
define_macros=define_macros,
extra_compile_args={ extra_compile_args={
'cxx': cxx_flags, 'cxx': cxx_flags,
'nvcc': cxx_flags 'nvcc': cxx_flags
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment