Commit 4cb75d42 authored by Rick Ho
Browse files

NCCL runs with 2 GPUs

parent d690c7b2
#include "comm_manager.h"

// Lazily-constructed process-wide singleton.
// NOTE(review): construction is not thread-safe; assumes the first call to
// getCommManager() happens before any concurrent use — confirm with callers.
CommManager* comm_mgr = nullptr;

// Returns the process-wide CommManager, constructing it on first use.
// The instance is intentionally never freed: it owns a NCCL communicator
// that must live for the whole process lifetime.
CommManager* getCommManager() {
    if (comm_mgr == nullptr) {
        comm_mgr = new CommManager();
    }
    return comm_mgr;
}
#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H

#include <cstdio>
#include <cstdlib>

#include <mpi.h>
#include "nccl.h"

// Abort with a diagnostic if an MPI call fails (MPI calls return int codes;
// ignoring them leads to silent corruption or later mysterious hangs).
#define COMM_MPI_CHECK(call)                                            \
    do {                                                                \
        int e_ = (call);                                                \
        if (e_ != MPI_SUCCESS) {                                        \
            fprintf(stderr, "MPI error %d at %s:%d\n", e_,              \
                    __FILE__, __LINE__);                                \
            abort();                                                    \
        }                                                               \
    } while (0)

// Abort with a diagnostic if a NCCL call fails.
#define COMM_NCCL_CHECK(call)                                           \
    do {                                                                \
        ncclResult_t r_ = (call);                                       \
        if (r_ != ncclSuccess) {                                        \
            fprintf(stderr, "NCCL error %s at %s:%d\n",                 \
                    ncclGetErrorString(r_), __FILE__, __LINE__);        \
            abort();                                                    \
        }                                                               \
    } while (0)

// Holds this process's MPI rank/size and a NCCL communicator spanning
// MPI_COMM_WORLD.  Rank 0 creates the NCCL unique id and broadcasts it so
// every rank joins the same communicator.
// Precondition: MPI_Init (or an equivalent) has already been called.
struct CommManager {
    int rank, size;       // this process's MPI rank and the world size
    ncclComm_t ncclcomm;  // NCCL communicator over all MPI ranks

    CommManager() {
        COMM_MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
        COMM_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
        ncclUniqueId uid;
        if (rank == 0) {
            COMM_NCCL_CHECK(ncclGetUniqueId(&uid));
        }
        // Every rank must initialize with the same unique id, so rank 0
        // shares the one it generated.
        COMM_MPI_CHECK(MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0,
                                 MPI_COMM_WORLD));
        // Collective call: all ranks must reach this point or it hangs.
        COMM_NCCL_CHECK(ncclCommInitRank(&ncclcomm, size, uid, rank));
    }

    ~CommManager() {
        // Release the NCCL communicator. Best-effort during teardown, so
        // the result is deliberately not checked.
        ncclCommDestroy(ncclcomm);
    }
};

CommManager* getCommManager();

#endif  // COMM_MANAGER_H
...@@ -71,7 +71,7 @@ class MOELayer_raw(nn.Module): ...@@ -71,7 +71,7 @@ class MOELayer_raw(nn.Module):
def reset_parameters(self): def reset_parameters(self):
for i in range(self.num_expert): for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat) linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
print(linear.weight.shape) # print(linear.weight.shape)
self.weight1.data[i] = linear.weight.data self.weight1.data[i] = linear.weight.data
linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat) linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
self.weight2.data[i] = linear.weight.data self.weight2.data[i] = linear.weight.data
...@@ -80,10 +80,10 @@ class MOELayer_raw(nn.Module): ...@@ -80,10 +80,10 @@ class MOELayer_raw(nn.Module):
gate_long = gate.long() gate_long = gate.long()
batch_size = inp.size(0) batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.out_feat)) x = inp.new_zeros((batch_size, self.out_feat))
print(self.weight2) # print(self.weight2)
for i in range(batch_size): for i in range(batch_size):
hid = inp[i] @ self.weight1[gate_long[i]].t() hid = inp[i] @ self.weight1[gate_long[i]].t()
print(hid) # print(hid)
x[i] = hid @ self.weight2[gate_long[i]].t() x[i] = hid @ self.weight2[gate_long[i]].t()
return x return x
......
...@@ -10,15 +10,18 @@ ...@@ -10,15 +10,18 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <helper_cuda.h> #include <helper_cuda.h>
#include <mpi.h>
#include "timer.hh" #include "timer.hh"
#include "cublas_wrapper.h" #include "cublas_wrapper.h"
#include "cuda_stream_manager.h" #include "cuda_stream_manager.h"
#include "comm_manager.h"
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1) #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_BREAKDOWN // #define MOE_BREAKDOWN
// #define MOE_DEBUG // #define MOE_DEBUG_SCATTER
template <typename scalar_t> template <typename scalar_t>
__global__ __global__
...@@ -68,19 +71,19 @@ void moe_cuda_forward_impl( ...@@ -68,19 +71,19 @@ void moe_cuda_forward_impl(
const size_t num_expert) { const size_t num_expert) {
auto h = getCudaStreamManager(num_expert); auto h = getCudaStreamManager(num_expert);
auto cm = getCommManager();
int tot_expert = num_expert * cm->size;
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_init); timestamp(t_init);
#endif #endif
scalar_t *input_buf, *hidden_buf, *output_buf; scalar_t *local_input_buf, *local_output_buf;
checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat)); in_feat));
checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&local_output_buf,
out_feat)); sizeof(scalar_t) * batch_size * out_feat));
checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
hidden_feat));
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_malloc); timestamp(t_malloc);
...@@ -89,8 +92,8 @@ void moe_cuda_forward_impl( ...@@ -89,8 +92,8 @@ void moe_cuda_forward_impl(
#endif #endif
int *gate = new int[batch_size]; int *gate = new int[batch_size];
int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert]; int *expert_count = new int[tot_expert], *expert_ptr = new int[tot_expert];
memset(expert_count, 0, sizeof(int) * num_expert); memset(expert_count, 0, sizeof(int) * tot_expert);
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
...@@ -105,7 +108,7 @@ void moe_cuda_forward_impl( ...@@ -105,7 +108,7 @@ void moe_cuda_forward_impl(
++expert_count[gate[i]]; ++expert_count[gate[i]];
} }
expert_ptr[0] = 0; expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) { for (int i = 1; i < tot_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1]; expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
} }
...@@ -119,6 +122,39 @@ void moe_cuda_forward_impl( ...@@ -119,6 +122,39 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
int *all_expert_count = new int[tot_expert];
MPI_Alltoall(expert_count, num_expert, MPI_INT,
all_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
int *expert_n = new int[num_expert];
int expert_sz = 0;
for (int i = 0; i < num_expert; ++i) {
expert_n[i] = 0;
for (int j = 0; j < cm->size; ++j) {
expert_n[i] += all_expert_count[j * num_expert + i];
}
expert_sz += expert_n[i];
}
scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
#ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d %d %d\n", cm->rank, i, expert_count[i]);
}
if (cm->rank == 0) {
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d ",all_expert_count[i]);
}
fprintf(stderr, "\n");
}
#endif
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_expert); timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) * fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
...@@ -127,9 +163,36 @@ void moe_cuda_forward_impl( ...@@ -127,9 +163,36 @@ void moe_cuda_forward_impl(
batch_scatter_kernel<scalar_t> batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input, <<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
input_buf); local_input_buf);
h->sync(0); h->sync(0);
ncclGroupStart();
int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert;
if (expert_count[send_id]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int recv_id = i * cm->size + j;
if (all_expert_count[recv_id]) {
ncclRecv(input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
}
}
}
ncclGroupEnd();
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
h->sync(); h->sync();
timestamp(t_scatter); timestamp(t_scatter);
...@@ -140,19 +203,19 @@ void moe_cuda_forward_impl( ...@@ -140,19 +203,19 @@ void moe_cuda_forward_impl(
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) { for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) { if (expert_n[i] == 0) {
continue; continue;
} }
#ifdef MOE_DEBUG_SCATTER #ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]); fprintf(stderr, "gemm %d sz %d\n", i, expert_n[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i], fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_n[i],
in_feat); in_feat);
#endif #endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->getHandle(i), checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
hidden_feat, expert_count[i], in_feat, hidden_feat, expert_n[i], in_feat,
&alpha, &alpha,
weight1 + i * in_feat * hidden_feat, in_feat, weight1 + i * in_feat * hidden_feat, in_feat,
input_buf + ptr * in_feat, in_feat, input_buf + ptr * in_feat, in_feat,
...@@ -163,7 +226,7 @@ void moe_cuda_forward_impl( ...@@ -163,7 +226,7 @@ void moe_cuda_forward_impl(
checkCudaErrors(cublasXgemm(h->getHandle(i), checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
out_feat, expert_count[i], hidden_feat, out_feat, expert_n[i], hidden_feat,
&alpha, &alpha,
weight2 + i * hidden_feat * out_feat, hidden_feat, weight2 + i * hidden_feat * out_feat, hidden_feat,
hidden_buf + hidden_feat * ptr, hidden_feat, hidden_buf + hidden_feat * ptr, hidden_feat,
...@@ -171,8 +234,9 @@ void moe_cuda_forward_impl( ...@@ -171,8 +234,9 @@ void moe_cuda_forward_impl(
output_buf + out_feat * ptr, out_feat output_buf + out_feat * ptr, out_feat
)); ));
ptr += expert_count[i]; ptr += expert_n[i];
} }
h->sync();
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_mm); timestamp(t_mm);
...@@ -180,10 +244,36 @@ void moe_cuda_forward_impl( ...@@ -180,10 +244,36 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
h->sync(); ncclGroupStart();
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert;
if (expert_count[recv_id]) {
ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat,
expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int send_id = i * cm->size + j;
if (all_expert_count[send_id]) {
ncclSend(input_buf + send_ptr * in_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
}
}
}
ncclGroupEnd();
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf, <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
output); local_output_buf, output);
h->sync(0); h->sync(0);
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
...@@ -197,6 +287,8 @@ void moe_cuda_forward_impl( ...@@ -197,6 +287,8 @@ void moe_cuda_forward_impl(
cudaFree(input_buf); cudaFree(input_buf);
cudaFree(hidden_buf); cudaFree(hidden_buf);
cudaFree(output_buf); cudaFree(output_buf);
cudaFree(local_input_buf);
cudaFree(local_output_buf);
cudaFree(d_pos); cudaFree(d_pos);
delete [] pos; delete [] pos;
delete [] gate; delete [] gate;
......
from moe import MOELayer from moe import MOELayer, MOELayer_raw
import torch import torch
import time import time
import sys import sys
def perf(): def perf():
torch.manual_seed(42 + torch.distributed.get_rank())
torch.cuda.manual_seed(42 + torch.distributed.get_rank())
batch_size = int(sys.argv[1]) batch_size = int(sys.argv[1])
io_feat = int(sys.argv[2]) io_feat = int(sys.argv[2])
hidden_feat = int(sys.argv[3]) hidden_feat = int(sys.argv[3])
num_expert = int(sys.argv[4]) num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, io_feat).cuda() inp = torch.rand(batch_size, io_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0,
high=num_expert * torch.distributed.get_world_size(),
size=(batch_size, ), requires_grad=False).int().cuda()
moe = MOELayer(num_expert, io_feat, hidden_feat, io_feat).cuda() moe = MOELayer(num_expert, io_feat, hidden_feat, io_feat).cuda()
o = moe(inp, gate) o = moe(inp, gate)
return
o = moe(inp, gate) o = moe(inp, gate)
o = moe(inp, gate) o = moe(inp, gate)
o = moe(inp, gate) o = moe(inp, gate)
...@@ -42,4 +48,6 @@ def perf(): ...@@ -42,4 +48,6 @@ def perf():
if __name__ == '__main__': if __name__ == '__main__':
torch.distributed.init_process_group(backend='mpi')
# print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
perf() perf()
...@@ -12,10 +12,17 @@ setup( ...@@ -12,10 +12,17 @@ setup(
sources=[ sources=[
'moe.cpp', 'moe.cpp',
'cuda_stream_manager.cpp', 'cuda_stream_manager.cpp',
'comm_manager.cpp',
'moe_cuda_kernel.cu', 'moe_cuda_kernel.cu',
], ],
extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)], extra_compile_args={
'nvcc': ['-I{}'.format(CUDA_HELPER)]} 'cxx': [
'-I{}'.format(CUDA_HELPER),
],
'nvcc': [
'-I{}'.format(CUDA_HELPER),
]
}
) )
], ],
cmdclass={ cmdclass={
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment