Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
2a77b772
Commit
2a77b772
authored
Sep 30, 2021
by
huchen1
Browse files
add fastmoe support rocm4.0.1
parent
d2392de2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
1 deletion
+84
-1
cuda/local_exchange.cuh
cuda/local_exchange.cuh
+9
-0
cuda/utils/cublas_wrapper.h
cuda/utils/cublas_wrapper.h
+17
-0
cuda/utils/helper_cuda.h
cuda/utils/helper_cuda.h
+48
-0
setup.py
setup.py
+10
-1
No files found.
cuda/local_exchange.cuh
View file @
2a77b772
...
@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl(
...
@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl(
}
}
#define PERTHREAD_EXPERTS 256
#define PERTHREAD_EXPERTS 256
#ifdef MOE_HIP_DIFF
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#define WARP_SIZE 32
#endif
__global__
__global__
void
expert_count_kernel
(
const
long
*
gate_idx
,
int
*
expert_count
,
void
expert_count_kernel
(
const
long
*
gate_idx
,
int
*
expert_count
,
...
@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count,
...
@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count,
int
x
=
res_tmp
[
i
-
expert_min
];
int
x
=
res_tmp
[
i
-
expert_min
];
#pragma unroll
#pragma unroll
for
(
int
j
=
1
;
j
<
WARP_SIZE
;
j
<<=
1
)
{
for
(
int
j
=
1
;
j
<
WARP_SIZE
;
j
<<=
1
)
{
#ifdef MOE_HIP_DIFF
x
=
x
+
__shfl_down
(
x
,
j
);
#else
x
=
x
+
__shfl_down_sync
(
-
1u
,
x
,
j
);
x
=
x
+
__shfl_down_sync
(
-
1u
,
x
,
j
);
#endif
}
}
if
(
threadIdx
.
x
%
WARP_SIZE
==
0
)
{
if
(
threadIdx
.
x
%
WARP_SIZE
==
0
)
{
atomicAdd
(
expert_count
+
i
,
x
);
atomicAdd
(
expert_count
+
i
,
x
);
...
...
cuda/utils/cublas_wrapper.h
View file @
2a77b772
...
@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
...
@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
const
__half
*
beta
,
const
__half
*
beta
,
__half
*
Carray
[],
int
ldc
,
__half
*
Carray
[],
int
ldc
,
int
batchCount
)
{
int
batchCount
)
{
#ifdef MOE_HIP_DIFF
return
rocblas_hgemm_batched
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
rocblas_half
*
)
alpha
,
(
const
rocblas_half
*
const
*
)
Aarray
,
lda
,
(
const
rocblas_half
*
const
*
)
Barray
,
ldb
,
(
const
rocblas_half
*
)
beta
,
(
rocblas_half
*
const
*
)
Carray
,
ldc
,
batchCount
);
#else
return
cublasHgemmBatched
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
Aarray
,
lda
,
Barray
,
ldb
,
beta
,
Carray
,
ldc
,
batchCount
);
return
cublasHgemmBatched
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
Aarray
,
lda
,
Barray
,
ldb
,
beta
,
Carray
,
ldc
,
batchCount
);
#endif
}
}
...
@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
...
@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const
__half
*
B
,
int
ldb
,
const
__half
*
B
,
int
ldb
,
const
__half
*
beta
,
const
__half
*
beta
,
__half
*
C
,
int
ldc
)
{
__half
*
C
,
int
ldc
)
{
#ifdef MOE_HIP_DIFF
return
rocblas_hgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
rocblas_half
*
)
alpha
,
(
const
rocblas_half
*
)
A
,
lda
,
(
const
rocblas_half
*
)
B
,
ldb
,
(
const
rocblas_half
*
)
beta
,
(
rocblas_half
*
)
C
,
ldc
);
#else
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
#endif
}
}
inline
cublasStatus_t
cublasXgemm
(
cublasHandle_t
handle
,
inline
cublasStatus_t
cublasXgemm
(
cublasHandle_t
handle
,
...
@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
...
@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const
c10
::
Half
*
B
,
int
ldb
,
const
c10
::
Half
*
B
,
int
ldb
,
const
c10
::
Half
*
beta
,
const
c10
::
Half
*
beta
,
c10
::
Half
*
C
,
int
ldc
)
{
c10
::
Half
*
C
,
int
ldc
)
{
#ifdef MOE_HIP_DIFF
return
rocblas_hgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
rocblas_half
*
)
alpha
,
(
const
rocblas_half
*
)
A
,
lda
,
(
const
rocblas_half
*
)
B
,
ldb
,
(
const
rocblas_half
*
)
beta
,
(
rocblas_half
*
)
C
,
ldc
);
#else
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
return
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
__half
*
)
alpha
,
(
const
__half
*
)
alpha
,
(
const
__half
*
)
A
,
lda
,
(
const
__half
*
)
A
,
lda
,
(
const
__half
*
)
B
,
ldb
,
(
const
__half
*
)
B
,
ldb
,
(
const
__half
*
)
beta
,
(
const
__half
*
)
beta
,
(
__half
*
)
C
,
ldc
);
(
__half
*
)
C
,
ldc
);
#endif
}
}
#endif // CUBLAS_WRAPPER_H
#endif // CUBLAS_WRAPPER_H
cuda/utils/helper_cuda.h
View file @
2a77b772
...
@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) {
...
@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) {
}
}
#endif
#endif
#ifdef MOE_HIP_DIFF
static
const
char
*
_cudaGetErrorEnum
(
cublasStatus_t
error
)
{
switch
(
error
)
{
case
rocblas_status_success
:
return
"rocblas_status_success"
;
case
rocblas_status_invalid_handle
:
return
"rocblas_status_invalid_handle"
;
case
rocblas_status_not_implemented
:
return
"rocblas_status_not_implemented"
;
case
rocblas_status_invalid_pointer
:
return
"rocblas_status_invalid_pointer:"
;
case
rocblas_status_invalid_size
:
return
"rocblas_status_invalid_size"
;
case
rocblas_status_memory_error
:
return
"rocblas_status_memory_error"
;
case
rocblas_status_internal_error
:
return
"rocblas_status_internal_error"
;
case
rocblas_status_perf_degraded
:
return
"rocblas_status_perf_degraded"
;
case
rocblas_status_size_query_mismatch
:
return
"rocblas_status_size_query_mismatch"
;
case
rocblas_status_size_increased
:
return
"rocblas_status_size_increased"
;
case
rocblas_status_size_unchanged
:
return
"rocblas_status_size_unchanged"
;
case
rocblas_status_invalid_value
:
return
"rocblas_status_invalid_value"
;
case
rocblas_status_continue
:
return
"rocblas_status_continue"
;
}
return
"<unknown>"
;
}
#else
// cuBLAS API errors
// cuBLAS API errors
static
const
char
*
_cudaGetErrorEnum
(
cublasStatus_t
error
)
{
static
const
char
*
_cudaGetErrorEnum
(
cublasStatus_t
error
)
{
switch
(
error
)
{
switch
(
error
)
{
...
@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) {
...
@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) {
return
"<unknown>"
;
return
"<unknown>"
;
}
}
#endif
#ifdef _CUFFT_H_
#ifdef _CUFFT_H_
// cuFFT API errors
// cuFFT API errors
...
...
setup.py
View file @
2a77b772
...
@@ -17,7 +17,15 @@ authors = [
...
@@ -17,7 +17,15 @@ authors = [
if
os
.
environ
.
get
(
'USE_NCCL'
,
'1'
)
==
'1'
:
if
os
.
environ
.
get
(
'USE_NCCL'
,
'1'
)
==
'1'
:
cxx_flags
.
append
(
'-DFMOE_USE_NCCL'
)
cxx_flags
.
append
(
'-DFMOE_USE_NCCL'
)
ext_libs
.
append
(
'nccl'
)
if
os
.
environ
.
get
(
'USE_ROCM'
,
'0'
)
==
'1'
:
ext_libs
.
append
(
'rccl'
)
else
:
ext_libs
.
append
(
'nccl'
)
if
os
.
environ
.
get
(
'USE_ROCM'
,
'0'
)
==
'1'
:
define_macros
=
[(
'MOE_HIP_DIFF'
,
None
)]
else
:
define_macros
=
[]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -41,6 +49,7 @@ if __name__ == '__main__':
...
@@ -41,6 +49,7 @@ if __name__ == '__main__':
'cuda/parallel_linear.cu'
,
'cuda/parallel_linear.cu'
,
'cuda/fmoe_cuda.cpp'
,
'cuda/fmoe_cuda.cpp'
,
],
],
define_macros
=
define_macros
,
extra_compile_args
=
{
extra_compile_args
=
{
'cxx'
:
cxx_flags
,
'cxx'
:
cxx_flags
,
'nvcc'
:
cxx_flags
'nvcc'
:
cxx_flags
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment