OpenDAS / FastMoE

Commit 293eef6d, authored Jan 28, 2021 by Rick Ho
Parent: a526f438

    hack nccl of pytorch
Showing 3 changed files with 37 additions and 14 deletions:

    cuda/cuda_stream_manager.cpp   +33  -11
    cuda/cuda_stream_manager.h      +3   -3
    cuda/moe.cpp                    +1   -0
cuda/cuda_stream_manager.cpp
@@ -2,12 +2,45 @@
 #include <mutex>
 #include <cassert>
 #include <thread>
+#include <iostream>
+
+#ifdef MOE_USE_NCCL
+#include <c10d/ProcessGroupNCCL.hpp>
+#endif  // MOE_USE_NCCL
 
 #include "cuda_stream_manager.h"
 #include <helper_cuda.h>
 
 #define SMGR_N_STREAMS 16
 
+#ifdef MOE_USE_NCCL
+class HackNCCLGroup: public c10d::ProcessGroupNCCL {
+public:
+    ncclComm_t getcomm(at::Device dev) {
+        auto key = std::to_string(dev.index());
+        auto v = getNCCLComm(key, {dev});
+        if (v.size() == 0) {
+            std::cerr << "PyTorch has nothing\n";
+            return 0;
+        }
+        return v[0]->getNcclComm();
+    }
+};
+
+void CudaStreamManager::ensure(void* torchp, at::Device dev) {
+    if (this->ncclgood) {
+        return;
+    }
+    HackNCCLGroup* h = (HackNCCLGroup*)torchp;
+    this->ncclcomm = h->getcomm(dev);
+    if (this->ncclcomm != 0) {
+        this->ncclgood = 1;
+    } else {
+        std::cerr << "Nccl initialization failed\n";
+    }
+}
+#endif  // MOE_USE_NCCL
+
 cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }

@@ -32,17 +65,6 @@ void CudaStreamManager::setup(const int device) {
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
-#ifdef MOE_USE_NCCL
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-    MPI_Comm_size(MPI_COMM_WORLD, &size);
-    ncclUniqueId uid;
-    if (rank == 0) {
-        ncclGetUniqueId(&uid);
-    }
-    MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
-    NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
-#endif
 }
 
 void CudaStreamManager::destroy() {
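The point of the first hunk is that PyTorch's c10d::ProcessGroupNCCL does not make its per-device NCCL communicators reachable from outside (getNCCLComm is not part of the public interface), so the commit subclasses the process group and casts the pointer handed over from Python to that subclass, reusing the communicator PyTorch has already initialized instead of building one over MPI. Below is a minimal, self-contained sketch of that cast-to-subclass pattern; Base, Hack, getResource, and grab are hypothetical stand-ins, not PyTorch names.

    #include <iostream>

    class Base {                        // stands in for c10d::ProcessGroupNCCL
    protected:
        int getResource() { return 42; }  // stands in for getNCCLComm(...)
    };

    class Hack : public Base {          // stands in for HackNCCLGroup
    public:
        int grab() { return getResource(); }  // protected member now reachable
    };

    int main() {
        Base b;                             // object created elsewhere (by PyTorch)
        Hack* h = static_cast<Hack*>(&b);   // same idea as (HackNCCLGroup*)torchp
        std::cout << h->grab() << "\n";     // prints 42
        return 0;
    }

The cast is outside what the C++ standard strictly guarantees, but it works here because the subclass adds no data members or virtual functions, so the object layout is unchanged.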
cuda/cuda_stream_manager.h
@@ -5,7 +5,6 @@
 #include <cublas_v2.h>
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 #define NCCL_SAFE_CALL(__fn__) { \

@@ -24,12 +23,13 @@ public:
     cublasHandle_t* handles;
     cudaStream_t* streams;
 
 #ifdef MOE_USE_NCCL
-    int rank, size;
+    char ncclgood;
     ncclComm_t ncclcomm;
+    void ensure(void*, class at::Device);
 #endif
 
 public:
-    CudaStreamManager(int device_): device(device_) {
+    CudaStreamManager(int device_): device(device_), ncclgood(0) {
         this->setup(device);
     }
cuda/moe.cpp
@@ -132,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
     m.def("global_fused_forward", &moe_global_fused_forward,
             "MoE global gather (CUDA)");
+    m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
 #endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
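The new ensure_nccl binding is the Python-facing entry point for the hack above; the moe_ensure_nccl function it points at is declared elsewhere and is not part of this diff. The following is a hypothetical sketch of what such a wrapper could look like, assuming PyTorch's own pybind11 bindings allow a ProcessGroupNCCL to be passed from Python and assuming a getCudaStreamManager(device) accessor for the per-device manager; neither assumption is confirmed by this commit.

    // Hypothetical wrapper for the ensure_nccl binding above; the real
    // moe_ensure_nccl lives outside this diff and may differ.
    #include <torch/extension.h>
    #include <c10d/ProcessGroupNCCL.hpp>
    #include "cuda_stream_manager.h"

    void moe_ensure_nccl_sketch(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
        // getCudaStreamManager is an assumed per-device accessor, not shown here.
        auto* smgr = getCudaStreamManager(t.device().index());
        // Hand the process group over as an opaque pointer; CudaStreamManager::ensure
        // casts it back to HackNCCLGroup* and extracts the NCCL communicator.
        smgr->ensure((void*)&p, t.device());
    }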