Commit 4cb75d42 authored by Rick Ho

nccl runs with 2 gpus

parent d690c7b2
#include "comm_manager.h"
// Process-wide singleton, constructed lazily on first use.
CommManager* comm_mgr = nullptr;
CommManager* getCommManager() {
if (!comm_mgr) {
comm_mgr = new CommManager();
}
return comm_mgr;
}
#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H
#include <mpi.h>
#include "nccl.h"
struct CommManager {
int rank, size;
ncclComm_t ncclcomm;
CommManager() {
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
// NCCL bootstrap: rank 0 creates the unique id and broadcasts it over MPI
// so that every rank joins the same communicator.
ncclUniqueId uid;
if (rank == 0) {
ncclGetUniqueId(&uid);
}
MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
ncclCommInitRank(&ncclcomm, size, uid, rank);
}
};
CommManager* getCommManager();
#endif // COMM_MANAGER_H
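A minimal usage sketch of the communicator set up above (assumptions: MPI is launched with one GPU per rank, the default CUDA stream is used, and the all-reduce is purely illustrative rather than part of this commit):

#include <mpi.h>
#include <cuda_runtime.h>
#include "nccl.h"
#include "comm_manager.h"

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    cudaSetDevice(rank);                  // assumes one GPU per MPI rank
    CommManager* cm = getCommManager();   // NCCL communicator bound to this device

    float* buf;
    cudaMalloc(&buf, sizeof(float) * 4);
    cudaMemset(buf, 0, sizeof(float) * 4);
    // sum the buffer across all ranks on the default stream
    ncclAllReduce(buf, buf, 4, ncclFloat, ncclSum, cm->ncclcomm, 0);
    cudaStreamSynchronize(0);

    cudaFree(buf);
    MPI_Finalize();
    return 0;
}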
......@@ -71,7 +71,7 @@ class MOELayer_raw(nn.Module):
def reset_parameters(self):
for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
print(linear.weight.shape)
# print(linear.weight.shape)
self.weight1.data[i] = linear.weight.data
linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
self.weight2.data[i] = linear.weight.data
......@@ -80,10 +80,10 @@ class MOELayer_raw(nn.Module):
gate_long = gate.long()
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.out_feat))
print(self.weight2)
# print(self.weight2)
for i in range(batch_size):
hid = inp[i] @ self.weight1[gate_long[i]].t()
print(hid)
# print(hid)
x[i] = hid @ self.weight2[gate_long[i]].t()
return x
......
......@@ -10,15 +10,18 @@
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <mpi.h>
#include "timer.hh"
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
#include "comm_manager.h"
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_BREAKDOWN
// #define MOE_DEBUG
// #define MOE_DEBUG_SCATTER
template <typename scalar_t>
__global__
......@@ -68,19 +71,19 @@ void moe_cuda_forward_impl(
const size_t num_expert) {
auto h = getCudaStreamManager(num_expert);
auto cm = getCommManager();
int tot_expert = num_expert * cm->size;
#ifdef MOE_BREAKDOWN
timestamp(t_init);
#endif
scalar_t *input_buf, *hidden_buf, *output_buf;
scalar_t *local_input_buf, *local_output_buf;
checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat));
checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
out_feat));
checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
hidden_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
#ifdef MOE_BREAKDOWN
timestamp(t_malloc);
......@@ -89,8 +92,8 @@ void moe_cuda_forward_impl(
#endif
int *gate = new int[batch_size];
int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert);
int *expert_count = new int[tot_expert], *expert_ptr = new int[tot_expert];
memset(expert_count, 0, sizeof(int) * tot_expert);
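// gate[i] is a global expert id in [0, tot_expert): expert e hosted on
// rank r has id r * num_expert + e, so expert_count / expert_ptr are
// indexed by global expert id on this rank.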
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));
......@@ -105,7 +108,7 @@ void moe_cuda_forward_impl(
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
for (int i = 1; i < tot_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
......@@ -119,6 +122,39 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
int *all_expert_count = new int[tot_expert];
MPI_Alltoall(expert_count, num_expert, MPI_INT,
all_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
int *expert_n = new int[num_expert];
int expert_sz = 0;
for (int i = 0; i < num_expert; ++i) {
expert_n[i] = 0;
for (int j = 0; j < cm->size; ++j) {
expert_n[i] += all_expert_count[j * num_expert + i];
}
expert_sz += expert_n[i];
}
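// After the all-to-all, all_expert_count[j * num_expert + i] is the number
// of rows rank j will send for local expert i, and expert_n[i] is the total
// row count this rank must push through expert i. For example, with 2 ranks
// and num_expert = 2, rank 0's expert_count = {3, 1, 2, 0} means 3 rows for
// global expert 0 and 2 rows for global expert 2 (rank 1's first expert).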
scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
#ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d %d %d\n", cm->rank, i, expert_count[i]);
}
if (cm->rank == 0) {
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d ",all_expert_count[i]);
}
fprintf(stderr, "\n");
}
#endif
#ifdef MOE_BREAKDOWN
timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
......@@ -127,9 +163,36 @@ void moe_cuda_forward_impl(
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
input_buf);
local_input_buf);
h->sync(0);
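// Exchange rows with every rank so that each expert sees all rows routed to
// it: for local expert i we send our block for (rank j, expert i) and receive
// the rows the other ranks gathered for our expert i, packed contiguously per
// (expert, source rank). Wrapping the sends/receives in ncclGroupStart /
// ncclGroupEnd lets NCCL fuse them into one operation instead of blocking on
// each call.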
ncclGroupStart();
int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert;
if (expert_count[send_id]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int recv_id = j * num_expert + i;
if (all_expert_count[recv_id]) {
ncclRecv(input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
}
}
}
ncclGroupEnd();
#ifdef MOE_BREAKDOWN
h->sync();
timestamp(t_scatter);
......@@ -140,19 +203,19 @@ void moe_cuda_forward_impl(
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
if (expert_n[i] == 0) {
continue;
}
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i],
fprintf(stderr, "gemm %d sz %d\n", i, expert_n[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_n[i],
in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
hidden_feat, expert_count[i], in_feat,
hidden_feat, expert_n[i], in_feat,
&alpha,
weight1 + i * in_feat * hidden_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
......@@ -163,7 +226,7 @@ void moe_cuda_forward_impl(
checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], hidden_feat,
out_feat, expert_n[i], hidden_feat,
&alpha,
weight2 + i * hidden_feat * out_feat, hidden_feat,
hidden_buf + hidden_feat * ptr, hidden_feat,
......@@ -171,8 +234,9 @@ void moe_cuda_forward_impl(
output_buf + out_feat * ptr, out_feat
));
ptr += expert_count[i];
ptr += expert_n[i];
}
h->sync();
#ifdef MOE_BREAKDOWN
timestamp(t_mm);
......@@ -180,10 +244,36 @@ void moe_cuda_forward_impl(
1e6);
#endif
h->sync();
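// Inverse of the scatter above: send each rank the expert outputs computed
// for its rows and receive back the outputs of our own rows, so they can be
// gathered into the original batch order below.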
ncclGroupStart();
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert;
if (expert_count[recv_id]) {
ncclRecv(local_output_buf + expert_ptr[recv_id] * out_feat,
expert_count[recv_id] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int send_id = j * num_expert + i;
if (all_expert_count[send_id]) {
ncclSend(output_buf + send_ptr * out_feat,
all_expert_count[send_id] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
}
}
}
ncclGroupEnd();
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
output);
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
local_output_buf, output);
h->sync(0);
#ifdef MOE_BREAKDOWN
......@@ -197,6 +287,8 @@ void moe_cuda_forward_impl(
cudaFree(input_buf);
cudaFree(hidden_buf);
cudaFree(output_buf);
cudaFree(local_input_buf);
cudaFree(local_output_buf);
cudaFree(d_pos);
delete [] pos;
delete [] gate;
......
from moe import MOELayer
from moe import MOELayer, MOELayer_raw
import torch
import time
import sys
def perf():
torch.manual_seed(42 + torch.distributed.get_rank())
torch.cuda.manual_seed(42 + torch.distributed.get_rank())
batch_size = int(sys.argv[1])
io_feat = int(sys.argv[2])
hidden_feat = int(sys.argv[3])
num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, io_feat).cuda()
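# gate assigns every row a global expert id in [0, num_expert * world_size)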
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
gate = torch.randint(low=0,
high=num_expert * torch.distributed.get_world_size(),
size=(batch_size, ), requires_grad=False).int().cuda()
moe = MOELayer(num_expert, io_feat, hidden_feat, io_feat).cuda()
o = moe(inp, gate)
return
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
......@@ -42,4 +48,6 @@ def perf():
if __name__ == '__main__':
torch.distributed.init_process_group(backend='mpi')
# print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
perf()
......@@ -12,12 +12,19 @@ setup(
sources=[
'moe.cpp',
'cuda_stream_manager.cpp',
'comm_manager.cpp',
'moe_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
'nvcc': ['-I{}'.format(CUDA_HELPER)]}
)
],
extra_compile_args={
'cxx': [
'-I{}'.format(CUDA_HELPER),
],
'nvcc': [
'-I{}'.format(CUDA_HELPER),
]
}
)
],
cmdclass={
'build_ext': BuildExtension
})