Commit ec322e4b authored by Rick Ho

Add global scatter/gather kernels and their PyTorch C++ interface functions

parent 7a2ad4a1
#include "comm_manager.h"
CommManager* comm_mgr = 0;
CommManager* getCommManager() {
if (!comm_mgr) {
comm_mgr = new CommManager();
}
return comm_mgr;
}
#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H

#include <cstdio>
#include <cstdlib>

#include <mpi.h>
#include "nccl.h"

#define NCCL_SAFE_CALL(__fn__) { \
    auto __res__ = __fn__; \
    if (__res__ != ncclSuccess) { \
        fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
        exit(-1); \
    } \
}

struct CommManager {
    int rank, size;
    ncclComm_t ncclcomm;

    // Assumes MPI_Init has already been called: rank 0 creates the NCCL
    // unique id and broadcasts it so that every rank joins the same
    // communicator.
    CommManager() {
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &size);
        ncclUniqueId uid;
        if (rank == 0) {
            ncclGetUniqueId(&uid);
        }
        MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
        NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
    }
};

CommManager* getCommManager();

#endif // COMM_MANAGER_H
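For context, a minimal usage sketch of the shared communicator; the allreduce_example helper and its use of ncclAllReduce are illustrative assumptions, not part of this commit:

// Hypothetical consumer of getCommManager(); assumes MPI_Init has already run
// and the caller owns a device buffer plus a CUDA stream.
#include <cuda_runtime.h>
#include "comm_manager.h"

void allreduce_example(float* d_buf, size_t count, cudaStream_t stream) {
    CommManager* cm = getCommManager();
    // Sum `count` floats in place across every rank of MPI_COMM_WORLD.
    NCCL_SAFE_CALL(ncclAllReduce(d_buf, d_buf, count, ncclFloat, ncclSum,
                                 cm->ncclcomm, stream));
}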
@@ -32,6 +32,17 @@ void CudaStreamManager::setup(const int device) {
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
+#ifdef MOE_USE_NCCL
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    ncclUniqueId uid;
+    if (rank == 0) {
+        ncclGetUniqueId(&uid);
+    }
+    MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
+    NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
+#endif
 }

 void CudaStreamManager::destroy() {
...
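One ordering caveat: the added block calls MPI_Comm_rank and MPI_Bcast, so MPI must already be initialized when the stream manager is set up. A minimal host-side sketch of that ordering, illustrative only and not part of this commit:

#include <mpi.h>

int main(int argc, char** argv) {
    // Initialize MPI before any code constructs a CudaStreamManager, since
    // setup() above performs the ncclCommInitRank handshake over
    // MPI_COMM_WORLD.
    MPI_Init(&argc, &argv);

    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    // ... bind this rank to a CUDA device and run the MoE operators here ...

    MPI_Finalize();
    return 0;
}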
@@ -4,11 +4,29 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>

+#ifdef MOE_USE_NCCL
+#include <mpi.h>
+#include <nccl.h>
+
+#define NCCL_SAFE_CALL(__fn__) { \
+    auto __res__ = __fn__; \
+    if (__res__ != ncclSuccess) { \
+        fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
+        exit(-1); \
+    } \
+}
+#endif
+
 class CudaStreamManager {
 public:
     int device;
     cublasHandle_t* handles;
     cudaStream_t* streams;
+#ifdef MOE_USE_NCCL
+    int rank, size;
+    ncclComm_t ncclcomm;
+#endif

 public:
     CudaStreamManager(int device_): device(device_) {
...
@@ -4,29 +4,7 @@
 #include <iostream>
 #include <vector>

-std::vector<torch::Tensor> moe_cuda_expert_count(
-        torch::Tensor gate, size_t num_expert);
-std::vector<torch::Tensor> moe_cuda_local_scatter(
-        torch::Tensor input,
-        torch::Tensor pos);
-std::vector<torch::Tensor> moe_cuda_local_gather(
-        torch::Tensor output_buf,
-        torch::Tensor pos);
-std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input_buf,
-        torch::Tensor weight,
-        torch::Tensor expert_count);
-std::vector<torch::Tensor> moe_cuda_backward(
-        torch::Tensor grad_output_buf,
-        torch::Tensor input_buf,
-        torch::Tensor weight,
-        torch::Tensor expert_count);
+#include "moe_cuda_kernel.h"

 // C++ interface

 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
@@ -87,6 +65,31 @@ std::vector<torch::Tensor> moe_backward(
     return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
 }

+#ifdef MOE_USE_NCCL
+std::vector<torch::Tensor> moe_global_scatter(
+        torch::Tensor input_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        size_t batch_size, size_t n_workers) {
+    CHECK_INPUT(input_buf);
+    return moe_cuda_global_scatter(input_buf,
+            local_expert_count, global_expert_count,
+            batch_size, n_workers);
+}
+
+std::vector<torch::Tensor> moe_global_gather(
+        torch::Tensor output_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        size_t batch_size, size_t n_workers) {
+    CHECK_INPUT(output_buf);
+    return moe_cuda_global_gather(output_buf,
+            local_expert_count, global_expert_count,
+            batch_size, n_workers);
+}
+#endif
+
 /*
 int main() {
@@ -103,6 +106,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
     m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
     m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
+#ifdef MOE_USE_NCCL
+    m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
+    m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
+#endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
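The bindings above expect the caller to supply local_expert_count, global_expert_count, and the receive-side batch_size. A hedged sketch of how a caller could derive those arguments with a plain MPI all-to-all over the per-(worker, expert) row counts; the helper name and the use of MPI_Alltoall here are assumptions, not part of this commit:

#include <mpi.h>
#include <numeric>
#include <vector>

// local_expert_count[e + w * num_expert] = rows this rank routes to expert e
// hosted on worker w. After the exchange, global_expert_count[e + w * num_expert]
// holds the rows worker w will send to this rank's expert e, and batch_size is
// the total number of rows this rank will receive.
std::vector<int> exchange_expert_counts(const std::vector<int>& local_expert_count,
                                        int num_expert, int world_size,
                                        long& batch_size) {
    std::vector<int> global_expert_count(num_expert * world_size);
    // Block j of the send buffer (num_expert ints) goes to rank j, matching
    // the (expert, worker) layout used by the CUDA kernels.
    MPI_Alltoall(local_expert_count.data(), num_expert, MPI_INT,
                 global_expert_count.data(), num_expert, MPI_INT,
                 MPI_COMM_WORLD);
    batch_size = std::accumulate(global_expert_count.begin(),
                                 global_expert_count.end(), 0L);
    return global_expert_count;
}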
@@ -102,8 +102,11 @@ def test():
     moe_raw.weight.data = moe.weight.data.clone()

     inp = torch.rand(batch_size, in_feat).cuda()
-    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
-    gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
+    gate = torch.randint(low=0,
+            high=num_expert * torch.distributed.get_world_size(),
+            size=(batch_size, ),
+            requires_grad=False).int().cuda()
+    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
...
-#include <torch/extension.h>
-#include <torch/torch.h>
+#include "moe_cuda_kernel.h"
 #include <cstdio>
 #include <iostream>
 #include <vector>
@@ -10,13 +10,16 @@
 #include <helper_cuda.h>
 #include <c10/cuda/CUDAGuard.h>

+#ifdef MOE_USE_NCCL
 #include <mpi.h>
+#include <nccl.h>
+#endif

 #include "timer.hh"
 #include "cublas_wrapper.h"
 #include "cuda_stream_manager.h"
+#include "comm_manager.h"

 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
@@ -79,79 +82,146 @@ void moe_cuda_expert_count_impl(
 #ifdef MOE_USE_NCCL
-void moe_cuda_global_scatter() {
-    if (cm->size > 1) {
-        if (expert_sz) {
-            checkCudaErrors(cudaMalloc(&input_buf,
-                    sizeof(scalar_t) * expert_sz * in_feat));
-            checkCudaErrors(cudaMalloc(&output_buf,
-                    sizeof(scalar_t) * expert_sz * out_feat));
-        }
-        int recv_ptr = 0;
-        for (int i = 0; i < num_expert; ++i) {
-            NCCL_SAFE_CALL(ncclGroupStart());
-            for (int j = 0; j < cm->size; ++j) {
-                int idx = i + j * num_expert;
-                if (expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclSend(
-                            local_input_buf + expert_ptr[idx] * in_feat,
-                            expert_count[idx] * in_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                }
-                if (all_expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclRecv(
-                            input_buf + recv_ptr * in_feat,
-                            all_expert_count[idx] * in_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                    recv_ptr += all_expert_count[idx];
-                }
-            }
-            NCCL_SAFE_CALL(ncclGroupEnd());
-        }
-    } else {
-        input_buf = local_input_buf;
-        output_buf = local_output_buf;
-    }
-}
-
-void moe_cuda_global_gather() {
-    if (cm->size > 1) {
-        int send_ptr = 0;
-        for (int i = 0; i < num_expert; ++i) {
-            NCCL_SAFE_CALL(ncclGroupStart());
-            for (int j = 0; j < cm->size; ++j) {
-                int idx = i + j * num_expert;
-                if (all_expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclSend(
-                            output_buf + send_ptr * out_feat,
-                            all_expert_count[idx] * out_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                    send_ptr += all_expert_count[idx];
-                }
-                if (expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclRecv(
-                            local_output_buf + expert_ptr[idx] * out_feat,
-                            expert_count[idx] * out_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                }
-            }
-            NCCL_SAFE_CALL(ncclGroupEnd());
-        }
-    }
-}
+template<typename scalar_t>
+void moe_cuda_global_scatter_impl(
+    const scalar_t* local_input_buf,
+    const int* local_expert_count,
+    const int* global_expert_count,
+    scalar_t* input_buf,
+    size_t in_feat, size_t num_expert, size_t world_size,
+    CudaStreamManager* smgr) {
+    // assert world_size > 1
+    int recv_ptr = 0;
+    /* TODO: may save for backward */
+    int *expert_ptr = new int[num_expert * world_size];
+    expert_ptr[0] = 0;
+    for (int i = 1; i < num_expert * world_size; ++i) {
+        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
+    }
+
+    for (int i = 0; i < num_expert; ++i) {
+        NCCL_SAFE_CALL(ncclGroupStart());
+        for (int j = 0; j < world_size; ++j) {
+            int idx = i + j * num_expert;
+            if (local_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclSend(
+                        local_input_buf + expert_ptr[idx] * in_feat,
+                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
+                        ncclChar,
+                        j,
+                        smgr->ncclcomm,
+                        smgr->stream(0)));
+            }
+            if (global_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclRecv(
+                        input_buf + recv_ptr * in_feat,
+                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
+                        ncclChar,
+                        j,
+                        smgr->ncclcomm,
+                        smgr->stream(0)));
+                recv_ptr += global_expert_count[idx];
+            }
+        }
+        NCCL_SAFE_CALL(ncclGroupEnd());
+    }
+    delete [] expert_ptr;
+}
+
+std::vector<torch::Tensor> moe_cuda_global_scatter(
+    torch::Tensor input_buf,
+    torch::Tensor local_expert_count,
+    torch::Tensor global_expert_count,
+    long batch_size, long n_workers) {
+    auto num_expert = local_expert_count.size(0) / n_workers;
+    auto in_feat = input_buf.size(1);
+    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
+    auto smgr = getCudaStreamManager(input_buf.device().index());
+
+    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
+            "moe_cuda_global_scatter", ([&] {
+        moe_cuda_global_scatter_impl<scalar_t>(
+            input_buf.data_ptr<scalar_t>(),
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            global_input_buf.data_ptr<scalar_t>(),
+            in_feat, num_expert, n_workers,
+            smgr
+        );
+    }));
+    return {global_input_buf,};
+}
+
+template<typename scalar_t>
+void moe_cuda_global_gather_impl(
+    const scalar_t* output_buf,
+    const int* local_expert_count,
+    const int* global_expert_count,
+    scalar_t* local_output_buf,
+    size_t out_feat, size_t num_expert, size_t world_size,
+    CudaStreamManager* smgr) {
+    int send_ptr = 0;
+    /* TODO: may save for backward */
+    int *expert_ptr = new int[num_expert * world_size];
+    expert_ptr[0] = 0;
+    for (int i = 1; i < num_expert * world_size; ++i) {
+        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
+    }
+
+    for (int i = 0; i < num_expert; ++i) {
+        NCCL_SAFE_CALL(ncclGroupStart());
+        for (int j = 0; j < world_size; ++j) {
+            int idx = i + j * num_expert;
+            if (global_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclSend(
+                        output_buf + send_ptr * out_feat,
+                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
+                        ncclChar,
+                        j,
+                        smgr->ncclcomm,
+                        smgr->stream(0)));
+                send_ptr += global_expert_count[idx];
+            }
+            if (local_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclRecv(
+                        local_output_buf + expert_ptr[idx] * out_feat,
+                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
+                        ncclChar,
+                        j,
+                        smgr->ncclcomm,
+                        smgr->stream(0)));
+            }
+        }
+        NCCL_SAFE_CALL(ncclGroupEnd());
+    }
+    delete [] expert_ptr;
+}
+
+std::vector<torch::Tensor> moe_cuda_global_gather(
+    torch::Tensor output_buf,
+    torch::Tensor local_expert_count,
+    torch::Tensor global_expert_count,
+    long batch_size, long n_workers) {
+    auto num_expert = local_expert_count.size(0) / n_workers;
+    auto out_feat = output_buf.size(1);
+    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
+    auto smgr = getCudaStreamManager(output_buf.device().index());
+
+    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
+            "moe_cuda_global_gather", ([&] {
+        moe_cuda_global_gather_impl<scalar_t>(
+            output_buf.data_ptr<scalar_t>(),
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            local_output_buf.data_ptr<scalar_t>(),
+            out_feat, num_expert, n_workers,
+            smgr
+        );
+    }));
+    return {local_output_buf,};
+}
 #endif // MOE_USE_NCCL
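To make the indexing in the two implementations concrete, here is a small host-side walk-through of the bookkeeping they assume (illustrative only; the example values are made up): the locally scattered buffer is ordered by destination (worker, expert) pair, and expert_ptr is the exclusive prefix sum of local_expert_count giving each slice's starting row.

#include <cstdio>

int main() {
    const int num_expert = 2, world_size = 2;
    // local_expert_count[e + w * num_expert]: rows this rank routes to
    // expert e hosted on worker w.
    int local_expert_count[4] = {3, 1, 2, 4};

    // Exclusive prefix sum, exactly as computed in the kernels above.
    int expert_ptr[4];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

    // Prints slice starts 0, 3, 4 and 6: these are the row offsets the
    // ncclSend calls read from, while recv_ptr packs incoming rows densely,
    // ordered by local expert and then by source worker.
    for (int i = 0; i < num_expert * world_size; ++i) {
        printf("worker %d / expert %d starts at row %d (count %d)\n",
               i / num_expert, i % num_expert, expert_ptr[i], local_expert_count[i]);
    }
    return 0;
}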
 template <typename scalar_t>
@@ -159,8 +229,8 @@ void moe_cuda_local_scatter_impl(
     const scalar_t* input,
     const int* d_pos,
     scalar_t* input_buf,
-    const size_t batch_size,
-    const size_t in_feat,
+    const long batch_size,
+    const long in_feat,
     CudaStreamManager* smgr) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
...
#ifndef MOE_CUDA_KERNEL_H
#define MOE_CUDA_KERNEL_H

#include <vector>

#include <torch/extension.h>
#include <torch/torch.h>

std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate, size_t num_expert);

std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_cuda_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

std::vector<torch::Tensor> moe_cuda_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);
#endif

#endif // MOE_CUDA_KERNEL_H
@@ -4,7 +4,7 @@ import time
 import sys

-dev_name = 'cuda:0'
+dev_name = 'cuda:1'

 def perf():
@@ -16,7 +16,7 @@ def perf():
     out_feat = int(sys.argv[3])
     num_expert = int(sys.argv[4])

-    inp = torch.rand(batch_size, io_feat).cuda(dev_name)
+    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
     gate = torch.randint(low=0,
             high=num_expert * torch.distributed.get_world_size(),
             size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
...
@@ -26,5 +26,5 @@ then
     done
 done
 else
-    python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
+    python3 $@ # 2>logs/$OMPI_COMM_WORLD_RANK.log
 fi
@@ -12,15 +12,16 @@ setup(
         sources=[
             'moe.cpp',
             'cuda_stream_manager.cpp',
+            'comm_manager.cpp',
             'moe_cuda_kernel.cu',
             ],
         extra_compile_args={
             'cxx': [
                 '-I{}'.format(CUDA_HELPER),
+                '-DMOE_USE_NCCL'
                 ],
             'nvcc': [
                 '-I{}'.format(CUDA_HELPER),
+                '-DMOE_USE_NCCL'
                 ]
             }
         )
...