Commit 293eef6d authored by Rick Ho

hack nccl of pytorch

parent a526f438
@@ -2,12 +2,45 @@
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#ifdef MOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#endif // MOE_USE_NCCL
#include "cuda_stream_manager.h" #include "cuda_stream_manager.h"
#include <helper_cuda.h> #include <helper_cuda.h>
#define SMGR_N_STREAMS 16 #define SMGR_N_STREAMS 16
#ifdef MOE_USE_NCCL
// Subclass exists only to reach the protected getNCCLComm of ProcessGroupNCCL.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        auto key = std::to_string(dev.index());
        auto v = getNCCLComm(key, {dev});
        if (v.size() == 0) {
            std::cerr << "PyTorch has nothing\n";
            return 0;
        }
        return v[0]->getNcclComm();
    }
};
void CudaStreamManager::ensure(void* torchp, at::Device dev) {
    if (this->ncclgood) {
        return;
    }
    // torchp is the ProcessGroupNCCL created by PyTorch; reuse its communicator.
    HackNCCLGroup* h = (HackNCCLGroup*)torchp;
    this->ncclcomm = h->getcomm(dev);
    if (this->ncclcomm != 0) {
        this->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}
#endif // MOE_USE_NCCL
cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}
@@ -32,17 +65,6 @@ void CudaStreamManager::setup(const int device) {
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
#ifdef MOE_USE_NCCL
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    ncclUniqueId uid;
    if (rank == 0) {
        ncclGetUniqueId(&uid);
    }
    MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
    NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
#endif
}
void CudaStreamManager::destroy() {
...
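The core of this change is a C++ access trick: c10d::ProcessGroupNCCL keeps its getNCCLComm helper protected, so the commit declares a derived class whose only job is to call it, then casts the process-group pointer handed over from Python to that subclass. A minimal standalone illustration of the idiom (generic names, not taken from the commit):

#include <iostream>

// Stand-in for a library class whose useful helper is protected.
class Base {
protected:
    int secret() { return 42; }
};

// The "hack" subclass adds no state; it exists only to re-expose the helper.
class Expose : public Base {
public:
    int get() { return secret(); }
};

int main() {
    Base b;
    // Casting Base* to Expose* mirrors what the commit does with the
    // ProcessGroupNCCL* received from Python. It works in practice because
    // Expose adds no data members, but it is not sanctioned by the standard,
    // hence the "hack" in the commit message.
    std::cout << static_cast<Expose*>(&b)->get() << "\n";
    return 0;
}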
@@ -5,7 +5,6 @@
#include <cublas_v2.h>
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
@@ -24,12 +23,13 @@ public:
    cublasHandle_t* handles;
    cudaStream_t* streams;
#ifdef MOE_USE_NCCL
    int rank, size;
    char ncclgood;
    ncclComm_t ncclcomm;
    void ensure(void*, class at::Device);
#endif
public:
    CudaStreamManager(int device_): device(device_) {
    CudaStreamManager(int device_): device(device_), ncclgood(0) {
        this->setup(device);
    }
...
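In the header, the MPI-era rank/size fields give way to a ncclgood flag plus the lazily called ensure(), so the communicator is only valid after ensure() has run. A hedged usage sketch of how calling code would typically guard on that flag before issuing a collective (the allreduce helper below is illustrative and not part of this commit):

#include <cassert>
#include "cuda_stream_manager.h"

// Illustrative only: a collective issued on the communicator that ensure()
// borrowed from PyTorch. buf must be a device pointer, and ensure_nccl must
// have been triggered from Python before any collective is launched.
void allreduce_grad(CudaStreamManager* smgr, float* buf, size_t n) {
    assert(smgr->ncclgood && "call ensure_nccl before any collective");
    NCCL_SAFE_CALL(ncclAllReduce(buf, buf, n, ncclFloat, ncclSum,
                smgr->ncclcomm, smgr->stream(0)));
}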
@@ -132,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)"); m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward, m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)"); "MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif #endif
m.def("forward", &moe_forward, "MoE forward (CUDA)"); m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)"); m.def("backward", &moe_backward, "MoE backward (CUDA)");
...
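The moe_ensure_nccl function registered above is not shown in this excerpt. A minimal sketch of what such a binding plausibly looks like, assuming a per-device accessor for the stream manager (getCudaStreamManager is an assumed name, not taken from this commit):

#include <torch/extension.h>
#ifdef MOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#include "cuda_stream_manager.h"

// Assumed accessor for the per-device CudaStreamManager; the real helper in
// the repository may have a different name or signature.
CudaStreamManager* getCudaStreamManager(int device);

void moe_ensure_nccl(c10d::ProcessGroupNCCL& pg, torch::Tensor t) {
    auto* smgr = getCudaStreamManager(t.device().index());
    // Hand the process group over as an opaque pointer; ensure() casts it to
    // HackNCCLGroup and extracts the underlying ncclComm_t.
    smgr->ensure(&pg, t.device());
}
#endif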