Unverified Commit b8aa893b authored by Jiezhong Qiu, committed by GitHub

Merge pull request #1 from xptree/laekov/multigpu

Faster MoE implementation for both single GPU and multiple GPUs
parents ef83c893 b9c28810
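For orientation, a minimal usage sketch of the layer this merge request introduces (an editorial illustration, not part of the diff; it follows the MOELayer API in moe.py and the test driver in moe_test.py shown below, and assumes the single-GPU path):

import torch
from moe import MOELayer

num_expert, in_feat, out_feat, batch_size = 4, 1024, 4096, 4096

# world_size=None (or 1) keeps everything on one GPU; passing
# torch.distributed.get_world_size() selects the NCCL multi-GPU path.
moe_layer = MOELayer(num_expert, in_feat, out_feat, world_size=None).cuda()

inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(0, num_expert, (batch_size,)).int().cuda()

out = moe_layer(inp, gate)    # [batch_size, out_feat]
out.mean().backward()         # gradients reach moe_layer.weight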
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include "cuda_stream_manager.h"
#include <helper_cuda.h>

#define SMGR_N_STREAMS 4
cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
#ifdef MOE_USE_NCCL
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
ncclUniqueId uid;
if (rank == 0) {
ncclGetUniqueId(&uid);
}
MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
#endif
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
@@ -3,51 +3,48 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <cstdio>

#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
exit(-1); \
} \
}
#endif
class CudaStreamManager {
public:
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
#ifdef MOE_USE_NCCL
int rank, size;
ncclComm_t ncclcomm;
#endif

public:
CudaStreamManager(int device_): device(device_) {
this->setup(device);
}
void setup(int);
void sync(int=0);
void destroy();
cudaStream_t stream(size_t=0);
cublasHandle_t handle(size_t=0);

~CudaStreamManager() {
this->destroy();
}
};

CudaStreamManager* getCudaStreamManager(const int device);

#endif  // CUDA_STREAM_MANAGER
@@ -4,58 +4,98 @@
#include <iostream>
#include <vector>

#include "moe_cuda_kernel.h"

// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
torch::Tensor gate,
size_t num_expert) {
CHECK_INPUT(gate);
return moe_cuda_expert_count(gate, num_expert);
}
std::vector<torch::Tensor> moe_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
CHECK_INPUT(input);
return moe_cuda_local_scatter(input, pos);
}
std::vector<torch::Tensor> moe_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
CHECK_INPUT(output_buf);
return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
torch::Tensor input_buf,    // [batch_size x in_feat]
torch::Tensor weight,       // [num_expert x out_feat x in_feat]
torch::Tensor expert_count  // [num_expert]
) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact that
Wx+b = [W b] [x]
             [1]
*/
return moe_cuda_forward(input_buf, weight, expert_count);
}

std::vector<torch::Tensor> moe_backward(
torch::Tensor grad_output_buf,  // [batch_size x out_feat]
torch::Tensor input_buf,        // [batch_size x in_feat]
torch::Tensor weight,           // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact that
Wx+b = [W b] [x]
             [1]
*/
return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_expert_exchange(
torch::Tensor local_expert_count,
size_t num_expert, size_t n_workers) {
return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}
std::vector<torch::Tensor> moe_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(input_buf);
return moe_cuda_global_scatter(input_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(output_buf);
return moe_cuda_global_gather(output_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
#endif
/*
int main() {
@@ -69,6 +109,14 @@ int main() {
*/

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
#ifdef MOE_USE_NCCL
m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)"); m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)"); m.def("backward", &moe_backward, "MoE backward (CUDA)");
} }
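The bias-merge identity noted in the comments of moe_forward and moe_backward above can be checked numerically; the following is an editorial NumPy sketch, not part of the extension:

import numpy as np

# Editorial check of Wx + b == [W b] @ [x; 1]
rng = np.random.default_rng(0)
out_feat, in_feat = 3, 5
W = rng.standard_normal((out_feat, in_feat))
b = rng.standard_normal(out_feat)
x = rng.standard_normal(in_feat)

lhs = W @ x + b
rhs = np.hstack([W, b[:, None]]) @ np.append(x, 1.0)
assert np.allclose(lhs, rhs)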
import math
from torch import nn
from torch.autograd import Function
import torch
import moe_cuda
from moe_function import moe
class MOEFunction(Function):
@staticmethod
def forward(ctx, inp, gate, weight):
out_feat, in_feat = weight.size()[1:]
weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
output = moe_cuda.forward(inp, gate, weight_column_major)
variables = [inp, gate, weight_column_major]
ctx.save_for_backward(*variables)
return output[0]
@staticmethod
def backward(ctx, grad_out):
# print("grad_out", grad_out)
# print("input", ctx.saved_tensors[0])
grad_inp, grad_weight = moe_cuda.backward(
grad_out.contiguous(), *ctx.saved_tensors)
out_feat, in_feat = grad_weight.size()[1:]
# print("grad_weight_column_major", grad_weight.flatten())
grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
return grad_inp, None, grad_weight_row_major
class MOELayer(nn.Module):
def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
world_size=None):
super(MOELayer, self).__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.world_size = world_size
self.weight = nn.Parameter(
torch.Tensor(num_expert, out_feat, in_feat))
self.reset_parameters()
@@ -45,22 +23,26 @@ class MOELayer(nn.Module):
self.weight.data[i] = linear.weight.data

def forward(self, inp, gate):
return moe(inp, gate.int(), self.weight, self.world_size)

class MOELayer_raw(nn.Module):
def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
world_size=0):
super(MOELayer_raw, self).__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.weight = nn.Parameter(
torch.Tensor(num_expert * world_size, out_feat, in_feat))
self.reset_parameters()

def reset_parameters(self):
for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat,
out_features=self.out_feat)
# print(linear.weight.shape)
self.weight.data[i] = linear.weight.data

def forward(self, inp, gate):
@@ -68,75 +50,5 @@ class MOELayer_raw(nn.Module):
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.out_feat))
for i in range(batch_size):
x[i] = inp[i] @ self.weight[gate_long[i]].t()
return x
def test():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
batch_size = 4
num_expert = 4
in_feat = 2
out_feat = 3
linear = nn.Linear(in_feat, in_feat).cuda()
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
moe_raw.weight.data = moe.weight.data.clone()
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
linear.zero_grad()
moe.zero_grad()
x = linear(inp)
output = moe(x, gate)
print("moe output", output)
y = output.mean()
y.backward()
print("moe.weight.grad", moe.weight.grad)
print("linear.weight.grad", linear.weight.grad)
print("linear.bias.grad", linear.bias.grad)
linear.zero_grad()
moe.zero_grad()
x = linear(inp.clone())
output_raw= moe_raw(x, gate.clone())
print("moe_raw output", output_raw)
y_raw = output_raw.mean()
y_raw.backward()
print("moe_raw.weight.grad", moe_raw.weight.grad)
print("linear_raw.weight.grad", linear.weight.grad)
print("linear_raw.bias.grad", linear.bias.grad)
def test_dp():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
batch_size = 6
num_expert = 4
in_feat = 2
out_feat = 3
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
print("data parallel of a nn.Linear model")
linear = nn.Linear(in_feat, in_feat).cuda()
linear_dp = torch.nn.DataParallel(linear, device_ids=[0,1,2])
output = linear_dp(inp)
print("successful!")
print("data parallel of our MoE model")
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=[0,1,2])
for i in range(5):
output = moe_dp(inp, gate)
if __name__ == '__main__':
# test()
test_dp()
\ No newline at end of file
#include "moe_cuda_kernel.h"

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
#endif

#include "timer.hh"
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
const int* offset, const scalar_t** ptrs) {
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < n) {
ptrs[idx] = base + stride * offset[idx];
}
}
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}

void moe_cuda_expert_count_impl(
const int* d_gate,
int* expert_count,
int* d_pos,
const size_t num_expert,
const size_t batch_size) {
int *gate = new int[batch_size];
int *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert);

checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));

for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}

int *pos = new int[batch_size];

for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
for (int i = num_expert - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));

delete [] gate;
delete [] expert_ptr;
}
#ifdef MOE_USE_NCCL
void moe_cuda_expert_exchange_impl(
const int* local_expert_count,
int* global_expert_count,
int* fwd_expert_count,
int num_expert, int world_size) {
MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < world_size; ++j) {
fwd_expert_count[i] += global_expert_count[i + j * num_expert];
}
}
}
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto fwe_options = torch::TensorOptions()
.dtype(local_expert_count.dtype());
auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);

moe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
fwd_expert_count.data_ptr<int>(),
num_expert, n_workers);

return {global_expert_count, fwd_expert_count};
}
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const int* local_expert_count,
const int* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
"moe_cuda_global_scatter", ([&] {
moe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
smgr
);
}));
return {global_input_buf,};
}
template<typename scalar_t>
void moe_cuda_global_gather_impl(
const scalar_t* output_buf,
const int* local_expert_count,
const int* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
int send_ptr = 0;
/* TODO: may save for backward */
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
"moe_cuda_global_gather", ([&] {
moe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
smgr
);
}));
return {local_output_buf,};
}
#endif // MOE_USE_NCCL
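The send/receive bookkeeping used by moe_cuda_global_scatter_impl and moe_cuda_global_gather_impl above can be summarized without any communication; the following editorial Python sketch only reproduces the offset and count arithmetic, with the buffer layout [expert-within-worker + worker * num_expert] taken from the code above:

def scatter_plan(local_expert_count, global_expert_count, num_expert, world_size):
    # Editorial sketch of the index bookkeeping behind the NCCL all-to-all above.
    # local_expert_count[e + w * num_expert]: rows this rank sends to expert e on worker w.
    # global_expert_count[e + w * num_expert]: rows worker w sends to this rank's expert e.
    expert_ptr = [0] * (num_expert * world_size)
    for i in range(1, num_expert * world_size):
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1]

    recv_ptr = 0
    plan = []
    for e in range(num_expert):           # one grouped NCCL round per local expert
        for w in range(world_size):
            idx = e + w * num_expert
            if local_expert_count[idx]:   # send offset into the locally scattered buffer
                plan.append(('send', w, expert_ptr[idx], local_expert_count[idx]))
            if global_expert_count[idx]:  # receive offset into the global buffer
                plan.append(('recv', w, recv_ptr, global_expert_count[idx]))
                recv_ptr += global_expert_count[idx]
    return plan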
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
const scalar_t* input,
const int* d_pos,
scalar_t* input_buf,
const long batch_size,
const long in_feat,
CudaStreamManager* smgr) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr->sync(1);
}
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
template <typename scalar_t>
void moe_cuda_local_gather_impl(
const scalar_t* output_buf,
const int* d_pos,
scalar_t* output,
const size_t batch_size,
const size_t out_feat,
CudaStreamManager* smgr) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr->sync(1);
}
template <typename scalar_t>
void moe_cuda_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
const int* expert_count,
scalar_t* output_buf,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
&beta,
output_buf + out_feat * ptr, out_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
template <typename scalar_t>
void moe_cuda_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
const int* expert_count,
scalar_t* grad_input_buf,
scalar_t* grad_weight,
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;

for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
cudaMemset(grad_weight + i * in_feat * out_feat, 0,
sizeof(scalar_t) * in_feat * out_feat);
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_N,
in_feat, expert_count[i], out_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_input_buf + in_feat * ptr, in_feat
));

// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_T,
in_feat, out_feat, expert_count[i],
&alpha,
input_buf + in_feat * ptr, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_weight + i * in_feat * out_feat, in_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);

auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
auto pos_options = torch::TensorOptions()
.device(gate.device())
.dtype(torch::kInt32);
auto pos = torch::empty(batch_size, pos_options);
moe_cuda_expert_count_impl(
gate.data_ptr<int>(),
expert_count.data_ptr<int>(),
pos.data_ptr<int>(),
num_expert,
batch_size);

return {expert_count, pos};
}
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = input.size(0);
const auto in_feat = input.size(1);
auto input_buf = torch::empty_like(input);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
([&] {
moe_cuda_local_scatter_impl<scalar_t>(
input.data_ptr<scalar_t>(),
pos.data_ptr<int>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat,
smgr);
}));
return {input_buf,};
}
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = output_buf.size(0);
const auto out_feat = output_buf.size(1);
auto output = torch::empty_like(output_buf);
AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
([&] {
moe_cuda_local_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
pos.data_ptr<int>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat,
smgr);
}));
return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif

auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
auto output = torch::empty({batch_size, out_feat}, out_options);

AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
([&] {
moe_cuda_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<int>(),
output.data_ptr<scalar_t>(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
@@ -162,56 +487,43 @@ std::vector<torch::Tensor> moe_cuda_forward(
}
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf,  // [batch_size x out_feat]
torch::Tensor input_buf,        // [batch_size x in_feat]
torch::Tensor weight,           // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif

auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<int>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));

return {grad_input_buf, grad_weight};
}
...
#ifndef MOE_CUDA_KERNEL_H
#define MOE_CUDA_KERNEL_H
#include <vector>
#include <torch/extension.h>
#include <torch/torch.h>
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers);
#endif
#endif // MOE_CUDA_KERNEL_H
import torch
from torch.autograd import Function
import moe_cuda
class MOELocal(Function):
@staticmethod
def forward(ctx, inp, gate, weight):
expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
input_buf, = moe_cuda.local_scatter(inp, pos)
output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
output = moe_cuda.local_gather(output_buf, pos)
variables = [input_buf, gate, weight, expert_count, pos]
ctx.save_for_backward(*variables)
return output[0]
@staticmethod
def backward(ctx, grad_out):
input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
grad_inp_buf, grad_weight = moe_cuda.backward(
grad_out_buf, input_buf, weight, expert_count)
grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
return grad_inp, None, grad_weight
class MOEGlobal(Function):
@staticmethod
def forward(ctx, inp, gate, weight, world_size):
num_expert = weight.shape[0]
local_expert_count, pos = moe_cuda.expert_count(gate,
world_size * num_expert)
global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
local_expert_count, num_expert, world_size)
fwd_batch_size = int(fwd_expert_count.sum().item())
local_input_buf, = moe_cuda.local_scatter(inp, pos)
global_input_buf, = moe_cuda.global_scatter(local_input_buf,
local_expert_count, global_expert_count,
fwd_batch_size, world_size)
global_output_buf, = moe_cuda.forward(global_input_buf, weight,
fwd_expert_count)
local_output_buf, = moe_cuda.global_gather(global_output_buf,
local_expert_count, global_expert_count,
inp.shape[0], world_size)
output, = moe_cuda.local_gather(local_output_buf, pos)
variables = (global_input_buf, gate, weight,
local_expert_count, global_expert_count, fwd_expert_count,
pos)
ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
ctx.save_for_backward(*variables)
return output
@staticmethod
def backward(ctx, grad_out):
(input_buf, gate, weight,
local_expert_count, global_expert_count, fwd_expert_count,
pos) = ctx.saved_tensors
num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args
grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
local_expert_count, global_expert_count,
fwd_batch_size, world_size)
grad_inp_buf, grad_weight = moe_cuda.backward(
global_grad_out_buf, input_buf, weight, fwd_expert_count)
local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
local_expert_count, global_expert_count,
local_batch_size, world_size)
grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)
return grad_inp, None, grad_weight, None
def moe(inp, gate, weight, world_size):
if world_size is not None and world_size > 1:
return MOEGlobal.apply(inp, gate, weight, world_size)
else:
return MOELocal.apply(inp, gate, weight)
from moe import MOELayer, MOELayer_raw
import torch
from torch import nn
import time
import sys

dev_name_default = 'cuda:0'

def perf():
torch.manual_seed(42 + torch.distributed.get_rank())
torch.cuda.manual_seed(42 + torch.distributed.get_rank())

if len(sys.argv) == 6:
batch_size = int(sys.argv[2])
in_feat = int(sys.argv[3])
out_feat = int(sys.argv[4])
num_expert = int(sys.argv[5])
else:
batch_size = 4096
in_feat = 1024
out_feat = 4096
num_expert = 4
if torch.distributed.get_rank() == 0:
print('Performance test case bs {} {}x{} ne {}'.format(batch_size,
in_feat, out_feat, num_expert))
if torch.distributed.get_world_size() > 1:
dev_name = 'cuda'
else:
dev_name = dev_name_default
inp = torch.rand(batch_size, in_feat).cuda(dev_name)
gate = torch.randint(low=0,
high=num_expert * torch.distributed.get_world_size(),
size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
moe.train()
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
n_runs = 16
tott = 0.
backt = 0.
maxt = 0.
sqtot = 0.
for i in range(n_runs):
gate = torch.randint(low=0,
high=num_expert * torch.distributed.get_world_size(),
size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
ts = time.time()
o = moe(inp, gate)
te = time.time()
loss = o.sum()
bts = time.time()
loss.backward()
bte = time.time()
tott += te - ts
sqtot += (te - ts)**2
maxt = max(maxt, te - ts)
backt = bte - bts
gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs,
backt * 1e3 / n_runs, gflops))
def test_module(moe, linear, inp, gate):
linear.zero_grad()
moe.zero_grad()
x = (linear(inp))
output = moe(x, gate)
# print('ooutput', torch.distributed.get_rank(), output)
y = output.mean()
y.backward()
return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
def test():
torch.manual_seed(42 + torch.distributed.get_rank())
torch.cuda.manual_seed(42 + torch.distributed.get_rank())
batch_size = 4
num_expert = 2
in_feat = 6
out_feat = 7
linear = nn.Linear(in_feat, in_feat).cuda()
if world_size > 1:
moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
else:
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
if world_size == 1:
moe_raw.weight.data = moe.weight.data.clone()
else:
weight_array = [torch.empty_like(moe.weight.data).cpu()
for _ in range(world_size)]
torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0,
high=num_expert * world_size,
size=(batch_size,),
requires_grad=False).int().cuda()
# gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
moe_out = test_module(moe, linear, inp.clone(), gate.clone())
raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
if world_size > 1:
rank = torch.distributed.get_rank()
ou, wg, lwg, lbg = raw_out
wg = wg.cpu()
torch.distributed.all_reduce(wg)
wg = wg[rank * num_expert:(rank + 1)* num_expert]
raw_out = ou, wg.cuda(), lwg, lbg
for name, mo, ro in zip(names, moe_out, raw_out):
err = (mo - ro).abs().sum()
print('{} abs err {}'.format(name, err))
def test_dp():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
batch_size = 6
num_expert = 4
in_feat = 2
out_feat = 3
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
print("data parallel of a nn.Linear model")
linear = nn.Linear(in_feat, in_feat).cuda()
linear_dp = torch.nn.DataParallel(linear, device_ids=[0,1,2])
output = linear_dp(inp)
print("successful!")
print("data parallel of our MoE model")
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=[0,1,2])
for i in range(5):
output = moe_dp(inp, gate)
if __name__ == '__main__':
torch.distributed.init_process_group(backend='mpi')
world_size = torch.distributed.get_world_size()
if len(sys.argv) == 2:
task = sys.argv[1]
print('Specificed task {}'.format(task))
if task == 'correctness':
test()
elif task == 'dp':
test_dp()
elif task == 'performance':
perf()
else:
test()
#!/bin/bash
if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
then
export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
fi
export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
python3 moe_test.py 2>logs/$OMPI_COMM_WORLD_RANK.log
elif [ .$1 = '.test_all' ]
then
for bs in 4 16 64
do
for inf in 1024 4096
do
for ouf in 1024 4096
do
for nexp in 4 16 64
do
echo $bs $nexp ${inf}x${ouf}
python moe_test.py $bs $inf $ouf $nexp
done
done
done
done
else
python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
fi
@@ -11,11 +11,19 @@ setup(
name='moe_cuda',
sources=[
'moe.cpp',
'cuda_stream_manager.cpp',
'moe_cuda_kernel.cu',
],
extra_compile_args={
'cxx': [
'-I{}'.format(CUDA_HELPER),
'-DMOE_USE_NCCL'
],
'nvcc': [
'-I{}'.format(CUDA_HELPER),
'-DMOE_USE_NCCL'
]
}
)
],
cmdclass={
...