Commit 94b68a3d authored by Rick Ho

overlap all-to-all communication with computation

parent 7c3e5149
@@ -95,6 +95,20 @@ std::vector<torch::Tensor> moe_global_gather(
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_global_fused_forward(
        input_buf, weight, local_expert_count, global_expert_count,
        global_batch_size, local_batch_size, n_workers);
}
#endif
/*
@@ -116,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
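After rebuilding the extension, the new entry point is registered next to the existing scatter/gather bindings. A quick sanity check, as a sketch only (moe_cuda is the import name the Python layer in this commit uses for the compiled extension):

import moe_cuda

# the fused op should now be exposed alongside the existing global ops
assert hasattr(moe_cuda, "global_fused_forward")
assert hasattr(moe_cuda, "global_scatter") and hasattr(moe_cuda, "global_gather")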
@@ -45,6 +45,13 @@ std::vector<torch::Tensor> moe_cuda_expert_exchange(
        torch::Tensor local_expert_count,
        long num_expert, long n_workers);

std::vector<torch::Tensor> moe_cuda_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers);
#endif
#endif // MOE_CUDA_KERNEL_H
@@ -40,16 +40,12 @@ class MOEGlobal(Function):
        fwd_batch_size = int(fwd_expert_count.sum().item())
        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        global_input_buf, = moe_cuda.global_scatter(local_input_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)
        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
                fwd_expert_count)
        local_output_buf, = moe_cuda.global_gather(global_output_buf,
        local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
                local_input_buf, weight,
                local_expert_count, global_expert_count,
                inp.shape[0], world_size)
                fwd_batch_size, inp.shape[0], world_size)
        output, = moe_cuda.local_gather(local_output_buf, pos)
        variables = (global_input_buf, gate, weight,
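For reference, the fused call is intended to return the same pair of tensors as the scatter → forward → gather sequence it replaces, while pipelining the NCCL exchange with the per-expert GEMMs. A hedged equivalence sketch, using only the bindings visible in this diff (the helper below is illustrative and not part of the codebase):

import moe_cuda

def unfused_reference(local_input_buf, weight, local_expert_count,
                      global_expert_count, fwd_expert_count,
                      fwd_batch_size, local_batch_size, world_size):
    # Previous path: three separate kernels, with no overlap between the
    # all-to-all exchange and the expert computation.
    global_input_buf, = moe_cuda.global_scatter(local_input_buf,
            local_expert_count, global_expert_count,
            fwd_batch_size, world_size)
    global_output_buf, = moe_cuda.forward(global_input_buf, weight,
            fwd_expert_count)
    local_output_buf, = moe_cuda.global_gather(global_output_buf,
            local_expert_count, global_expert_count,
            local_batch_size, world_size)
    return local_output_buf, global_input_buf

# New path: a single call that should produce the same two tensors.
# local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
#         local_input_buf, weight, local_expert_count, global_expert_count,
#         fwd_batch_size, local_batch_size, world_size)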
@@ -11,12 +11,138 @@
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
// TODO
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        long in_feat, long out_feat,
        long num_expert, long world_size,
        CudaStreamManager* smgr) {
    int ptr = 0;
    int send_ptr = 0;
    int recv_ptr = 0;
    // Prefix sum of local_expert_count: offset of each (expert, worker)
    // slice inside the locally sorted input buffer.
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    scalar_t alpha = 1, beta = 0;

    for (int i = 0; i < num_expert; ++i) {
        int expert_count = 0;
        // Exchange the tokens routed to expert i: send the local rows routed
        // to expert i on worker j, and receive the rows other workers routed
        // to this worker's expert i. Each expert uses its own stream so the
        // exchange for expert i can proceed while earlier experts compute.
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        global_input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
                recv_ptr += global_expert_count[idx];
                expert_count += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
        // Expert i's GEMM is issued right after its exchange, so it can run
        // while the exchange for the remaining experts is still in flight.
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count, in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                global_input_buf + ptr * in_feat, in_feat,
                &beta,
                global_output_buf + out_feat * ptr, out_feat
                ));
        ptr += expert_count;
        // Send expert i's outputs back to the workers the tokens came from,
        // and receive this worker's own results into output_buf at the same
        // offsets the inputs were read from.
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        global_output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
    auto smgr = getCudaStreamManager(input_buf.device().index());

    auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
    auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
    auto output_buf = input_buf.new_empty({local_batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
            "moe_cuda_global_fused_forward", ([&] {
        moe_cuda_global_fused_forward_impl(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            in_feat, out_feat, num_expert, n_workers,
            smgr);
    }));
    return {output_buf, global_input_buf};
}
#endif
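The size arguments mirror how the Python caller computes them: local_batch_size is the number of rows in input_buf (the locally sorted tokens), and global_batch_size is the total number of rows this worker's experts receive. A small illustration of that bookkeeping, with hypothetical sizes and without invoking the kernel (the real call additionally requires an initialized NCCL communicator):

import torch

# hypothetical sizes, for illustration only
num_expert, n_workers, in_feat, out_feat = 2, 1, 4, 8
local_expert_count = torch.tensor([3, 1], dtype=torch.int32)
global_expert_count = torch.tensor([3, 1], dtype=torch.int32)

local_batch_size = int(local_expert_count.sum())    # rows of input_buf on this worker
global_batch_size = int(global_expert_count.sum())  # rows gathered for the local experts

input_buf = torch.randn(local_batch_size, in_feat)
weight = torch.randn(num_expert, out_feat, in_feat)  # weight.size(1)=out_feat, size(2)=in_feat

# The binding would then be called as:
# out, global_inp = moe_cuda.global_fused_forward(
#         input_buf, weight, local_expert_count, global_expert_count,
#         global_batch_size, local_batch_size, n_workers)
# out:        (local_batch_size, out_feat), gathered back to this worker
# global_inp: (global_batch_size, in_feat), returned for use in the backward pass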