Commit 293eef6d authored by Rick Ho

Hack the NCCL communicator out of PyTorch

parent a526f438
@@ -2,12 +2,45 @@
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#ifdef MOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#endif // MOE_USE_NCCL
#include "cuda_stream_manager.h"
#include <helper_cuda.h>
#define SMGR_N_STREAMS 16
#ifdef MOE_USE_NCCL
// Subclass ProcessGroupNCCL only to reach its protected getNCCLComm(),
// so the communicator PyTorch has already created can be reused.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        auto key = std::to_string(dev.index());
        auto v = getNCCLComm(key, {dev});
        if (v.size() == 0) {
            std::cerr << "PyTorch has nothing\n";
            return 0;
        }
        return v[0]->getNcclComm();
    }
};

// Take the NCCL communicator from the PyTorch process group passed in as
// torchp; later calls retry until a communicator has been obtained.
void CudaStreamManager::ensure(void* torchp, at::Device dev) {
    if (this->ncclgood) {
        return;
    }
    HackNCCLGroup* h = (HackNCCLGroup*)torchp;
    this->ncclcomm = h->getcomm(dev);
    if (this->ncclcomm != 0) {
        this->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}
#endif // MOE_USE_NCCL
cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}
@@ -32,17 +65,6 @@ void CudaStreamManager::setup(const int device) {
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
#ifdef MOE_USE_NCCL
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    ncclUniqueId uid;
    if (rank == 0) {
        ncclGetUniqueId(&uid);
    }
    MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
    NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
#endif
}
void CudaStreamManager::destroy() {
@@ -5,7 +5,6 @@
#include <cublas_v2.h>
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
@@ -24,12 +23,13 @@ public:
    cublasHandle_t* handles;
    cudaStream_t* streams;
#ifdef MOE_USE_NCCL
    int rank, size;
    char ncclgood;
    ncclComm_t ncclcomm;
    void ensure(void*, class at::Device);
#endif
public:
    CudaStreamManager(int device_): device(device_) {
    CudaStreamManager(int device_): device(device_), ncclgood(0) {
        this->setup(device);
    }
@@ -132,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
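The diff registers `ensure_nccl` but does not show the definition of `moe_ensure_nccl` itself. Below is a minimal sketch of what such a wrapper could look like, assuming it receives the `c10d::ProcessGroupNCCL` and a tensor on the target device from Python and simply forwards them to `CudaStreamManager::ensure`; the signature and the `getCudaStreamManager` accessor are illustrative, not taken from this commit.

// Hypothetical sketch (not part of this commit): pass PyTorch's NCCL
// process group into the stream manager so the hacked getcomm() above
// can pull out the already-initialized communicator.
#include <torch/extension.h>
#include <c10d/ProcessGroupNCCL.hpp>
#include "cuda_stream_manager.h"

// Assumed accessor returning the per-device CudaStreamManager instance;
// the real codebase may expose it under a different name.
CudaStreamManager* getCudaStreamManager(int device);

void moe_ensure_nccl(c10d::ProcessGroupNCCL& pg, torch::Tensor t) {
    CudaStreamManager* smgr = getCudaStreamManager(t.device().index());
    smgr->ensure((void*)&pg, t.device());
}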