OpenDAS / FastMoE / Commits

Commit 368c8e41 (unverified), authored Sep 30, 2021 by Rick Ho, committed by GitHub on Sep 30, 2021
Parents: d2392de2, 49e81f33

Merge pull request #78 from hclearner/fastmoe-rocm4.0.1

add fastmoe support rocm4.0.1
Showing 4 changed files with 90 additions and 2 deletions:

- cuda/local_exchange.cuh      +9  -0
- cuda/utils/cublas_wrapper.h  +17 -0
- cuda/utils/helper_cuda.h     +48 -0
- setup.py                     +16 -2
cuda/local_exchange.cuh

@@ -27,7 +27,12 @@ void fmoe_cuda_assign_pos_impl(
 }
 
 #define PERTHREAD_EXPERTS 256
 
+#ifdef FMOE_USE_HIP
+#define WARP_SIZE 64
+#else
 #define WARP_SIZE 32
+#endif
 
 __global__
 void expert_count_kernel(const long* gate_idx, int* expert_count,
@@ -52,7 +57,11 @@ void expert_count_kernel(const long* gate_idx, int* expert_count,
         int x = res_tmp[i - expert_min];
 #pragma unroll
         for (int j = 1; j < WARP_SIZE; j <<= 1) {
+#ifdef FMOE_USE_HIP
+            x = x + __shfl_down(x, j);
+#else
             x = x + __shfl_down_sync(-1u, x, j);
+#endif
         }
         if (threadIdx.x % WARP_SIZE == 0) {
             atomicAdd(expert_count + i, x);
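These two hunks adapt the kernel to AMD hardware: a GCN/CDNA wavefront is 64 lanes wide (versus 32 for an NVIDIA warp), and HIP exposes the unsynchronized __shfl_down rather than CUDA 9's __shfl_down_sync, whose -1u argument is the full-warp participation mask (all bits set). A minimal sketch of the resulting reduction, pulled out into a standalone helper — warp_reduce_sum is a hypothetical name, not part of the commit:

// Sketch only: the portable warp/wavefront sum used by expert_count_kernel,
// assuming WARP_SIZE and FMOE_USE_HIP are defined as in the hunks above.
__device__ __forceinline__ int warp_reduce_sum(int x) {
    for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef FMOE_USE_HIP
        x += __shfl_down(x, j);            // HIP: shuffle across the 64-lane wavefront
#else
        x += __shfl_down_sync(-1u, x, j);  // CUDA: -1u selects every lane in the warp
#endif
    }
    return x;  // lane 0 of each warp ends up holding the full warp sum
}

After the loop only lane 0 holds the complete sum, which is why the kernel gates its atomicAdd on threadIdx.x % WARP_SIZE == 0.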
cuda/utils/cublas_wrapper.h

@@ -39,7 +39,11 @@ inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         const __half *beta,
         __half *Carray[], int ldc,
         int batchCount) {
+#ifdef FMOE_USE_HIP
+    return rocblas_hgemm_batched(handle, transa, transb, m, n, k,
+            (const rocblas_half*)alpha, (const rocblas_half* const*)Aarray, lda,
+            (const rocblas_half* const*)Barray, ldb, (const rocblas_half*)beta,
+            (rocblas_half* const*)Carray, ldc, batchCount);
+#else
     return cublasHgemmBatched(handle, transa, transb, m, n, k,
             alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
+#endif
 }
@@ -73,7 +77,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
         const __half *B, int ldb,
         const __half *beta,
         __half *C, int ldc) {
+#ifdef FMOE_USE_HIP
+    return rocblas_hgemm(handle, transa, transb, m, n, k,
+            (const rocblas_half*)alpha, (const rocblas_half*)A, lda,
+            (const rocblas_half*)B, ldb, (const rocblas_half*)beta,
+            (rocblas_half*)C, ldc);
+#else
     return cublasHgemm(handle, transa, transb, m, n, k,
             alpha, A, lda, B, ldb, beta, C, ldc);
+#endif
 }
 
 inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
@@ -84,12 +92,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
         const c10::Half *B, int ldb,
         const c10::Half *beta,
         c10::Half *C, int ldc) {
+#ifdef FMOE_USE_HIP
+    return rocblas_hgemm(handle, transa, transb, m, n, k,
+            (const rocblas_half*)alpha, (const rocblas_half*)A, lda,
+            (const rocblas_half*)B, ldb, (const rocblas_half*)beta,
+            (rocblas_half*)C, ldc);
+#else
     return cublasHgemm(handle, transa, transb, m, n, k,
             (const __half*)alpha, (const __half*)A, lda,
             (const __half*)B, ldb, (const __half*)beta,
             (__half*)C, ldc);
+#endif
 }
 
 #endif  // CUBLAS_WRAPPER_H
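These wrappers work because rocBLAS mirrors cuBLAS's GEMM argument order, so each overload only has to cast the half-precision pointer types (__half / c10::Half to rocblas_half); under a ROCm build, PyTorch's hipify pass presumably maps cublasHandle_t, cublasOperation_t, and cublasStatus_t to their rocBLAS counterparts, which is why the rocBLAS calls can be returned through a cublasStatus_t signature. A hypothetical single-precision overload in the same style (not part of this diff; the file's real float path may differ) would need no casts at all:

// Sketch only: a float overload following the same dispatch pattern.
// cublasSgemm and rocblas_sgemm take identical argument orders.
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha, const float *A, int lda,
        const float *B, int ldb,
        const float *beta, float *C, int ldc) {
#ifdef FMOE_USE_HIP
    return rocblas_sgemm(handle, transa, transb, m, n, k,
            alpha, A, lda, B, ldb, beta, C, ldc);
#else
    return cublasSgemm(handle, transa, transb, m, n, k,
            alpha, A, lda, B, ldb, beta, C, ldc);
#endif
}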
cuda/utils/helper_cuda.h

@@ -51,6 +51,53 @@ static const char *_cudaGetErrorEnum(CUresult error) {
 }
 #endif
 
+#ifdef FMOE_USE_HIP
+static const char *_cudaGetErrorEnum(cublasStatus_t error) {
+    switch (error) {
+        case rocblas_status_success:
+            return "rocblas_status_success";
+        case rocblas_status_invalid_handle:
+            return "rocblas_status_invalid_handle";
+        case rocblas_status_not_implemented:
+            return "rocblas_status_not_implemented";
+        case rocblas_status_invalid_pointer:
+            return "rocblas_status_invalid_pointer";
+        case rocblas_status_invalid_size:
+            return "rocblas_status_invalid_size";
+        case rocblas_status_memory_error:
+            return "rocblas_status_memory_error";
+        case rocblas_status_internal_error:
+            return "rocblas_status_internal_error";
+        case rocblas_status_perf_degraded:
+            return "rocblas_status_perf_degraded";
+        case rocblas_status_size_query_mismatch:
+            return "rocblas_status_size_query_mismatch";
+        case rocblas_status_size_increased:
+            return "rocblas_status_size_increased";
+        case rocblas_status_size_unchanged:
+            return "rocblas_status_size_unchanged";
+        case rocblas_status_invalid_value:
+            return "rocblas_status_invalid_value";
+        case rocblas_status_continue:
+            return "rocblas_status_continue";
+    }
+    return "<unknown>";
+}
+#else
 // cuBLAS API errors
 static const char *_cudaGetErrorEnum(cublasStatus_t error) {
     switch (error) {
@@ -87,6 +134,7 @@ static const char *_cudaGetErrorEnum(cublasStatus_t error) {
     return "<unknown>";
 }
+#endif
 
 #ifdef _CUFFT_H_
 // cuFFT API errors
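The new #ifdef FMOE_USE_HIP overload gives ROCm builds a readable name for each rocblas_status_* code. _cudaGetErrorEnum follows the convention of NVIDIA's helper_cuda.h, where the overload set feeds a generic status-checking helper; a minimal sketch of that consumer, modeled on the CUDA samples header this file derives from (the exact names in this file may differ):

// Sketch only: how _cudaGetErrorEnum overloads are typically consumed.
#include <cstdio>
#include <cstdlib>

template <typename T>
void check(T result, const char *func, const char *file, int line) {
    if (result) {  // CUBLAS_STATUS_SUCCESS and rocblas_status_success are both 0
        fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\"\n",
                file, line, static_cast<int>(result),
                _cudaGetErrorEnum(result), func);  // overload resolved by status type
        exit(EXIT_FAILURE);
    }
}
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)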
setup.py

 import setuptools
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 import os
 import torch
 
 cxx_flags = []
 ext_libs = []
@@ -15,9 +15,22 @@ authors = [
           'Qin Li',
           ]
 
+is_rocm_pytorch = False
+if torch.__version__ >= '1.5':
+    from torch.utils.cpp_extension import ROCM_HOME
+    is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
+
 if os.environ.get('USE_NCCL', '1') == '1':
     cxx_flags.append('-DFMOE_USE_NCCL')
-    ext_libs.append('nccl')
+    if is_rocm_pytorch:
+        ext_libs.append('rccl')
+    else:
+        ext_libs.append('nccl')
+
+if is_rocm_pytorch:
+    define_macros = [('FMOE_USE_HIP', None)]
+else:
+    define_macros = []
 
 if __name__ == '__main__':
@@ -41,6 +54,7 @@ if __name__ == '__main__':
                         'cuda/parallel_linear.cu',
                         'cuda/fmoe_cuda.cpp',
                         ],
+                define_macros=define_macros,
                 extra_compile_args={
                     'cxx': cxx_flags,
                     'nvcc': cxx_flags
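The build-script change detects a ROCm PyTorch (both torch.version.hip and ROCM_HOME must be present), links rccl in place of nccl, and passes the FMOE_USE_HIP macro via define_macros — the macro every #ifdef in the C++/CUDA sources above keys on. A minimal sketch of how a translation unit can consume that macro (illustrative only, not a file from this commit):

// Sketch only: selecting the runtime header based on the macro
// that setup.py defines for ROCm builds.
#ifdef FMOE_USE_HIP
#include <hip/hip_runtime.h>   // ROCm build
#else
#include <cuda_runtime.h>      // CUDA build
#endif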