Commit 307e0ad9 authored by Rick Ho

Merge branch 'laekov/batching' into laekov/multigpu

parents cbd86de8 861b75c1
#include <cuda_runtime.h>
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include "cuda_stream_manager.h"
#include <helper_cuda.h>
#define SMGR_N_STREAMS 4

cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}

cublasHandle_t CudaStreamManager::handle(size_t idx) {
    return this->handles[idx % SMGR_N_STREAMS];
}

// Synchronize the first `idx` streams; a negative index (the default)
// synchronizes all of them.
void CudaStreamManager::sync(int idx) {
    if (idx < 0 || idx > SMGR_N_STREAMS) {
        idx = SMGR_N_STREAMS;
    }
    for (int i = 0; i < idx; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}
void CudaStreamManager::setup(const int device) {
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamCreate(streams + i));
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
}

void CudaStreamManager::destroy() {
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamDestroy(streams[i]));
        checkCudaErrors(cublasDestroy(handles[i]));
    }
    delete[] streams;
    delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;

// One CudaStreamManager per device, created lazily with double-checked locking.
CudaStreamManager* getCudaStreamManager(const int device) {
    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        smgr_mtx.lock();
        it = smgrs.find(device);
        if (it == smgrs.end()) {
            auto smgr = new CudaStreamManager(device);
            smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
            smgr_mtx.unlock();
            return smgr;
        } else {
            smgr_mtx.unlock();
        }
    }
    return it->second;
}
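// Illustrative usage (a sketch, not part of this commit): callers fetch the
// per-device manager once and then round-robin work across its streams and
// cuBLAS handles, e.g.
//
//     auto smgr = getCudaStreamManager(input.device().index());
//     some_kernel<<<grid, block, 0, smgr->stream(i)>>>(...);  // hypothetical kernel
//     cublasSgemm(smgr->handle(i), ...);
//     smgr->sync(n_streams_used);
//
// which is the pattern the CUDA code later in this commit follows.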
@@ -3,44 +3,30 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
class CudaStreamManager {
public:
    int device;
    cublasHandle_t* handles;
    cudaStream_t* streams;

    CudaStreamManager(int device_): device(device_) {
        this->setup(device);
    }

    void setup(int);
    void sync(int=-1);
    void destroy();

    cudaStream_t stream(size_t=0);
    cublasHandle_t handle(size_t=0);

    ~CudaStreamManager() {
        this->destroy();
    }
};

CudaStreamManager* getCudaStreamManager(const int device);
#endif // CUDA_STREAM_MANAGER
@@ -4,17 +4,27 @@
#include <iostream>
#include <vector>
std::vector<torch::Tensor> moe_cuda_expert_count(
    torch::Tensor gate, size_t num_expert);

std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
    torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_forward(
    torch::Tensor input_buf,
    torch::Tensor weight,
    torch::Tensor expert_count);

std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf,
    torch::Tensor input_buf,
    torch::Tensor weight,
    torch::Tensor expert_count);
// C++ interface
@@ -23,40 +33,58 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
torch::Tensor gate,
size_t num_expert) {
CHECK_INPUT(gate);
return moe_cuda_expert_count(gate, num_expert);
}
std::vector<torch::Tensor> moe_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
CHECK_INPUT(input);
return moe_cuda_local_scatter(input, pos);
}
std::vector<torch::Tensor> moe_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
CHECK_INPUT(output_buf);
return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
    torch::Tensor input_buf,   // [batch_size x in_feat]
    torch::Tensor weight,      // [num_expert x out_feat x in_feat]
    torch::Tensor expert_count // [num_expert]
) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    /*
        The bias term should have been merged into the weight. Note that
            W x + b = [W  b] [x]
                             [1],
        so a bias can be absorbed by appending it as an extra column of W and
        padding the input with a trailing 1.
    */
    return moe_cuda_forward(input_buf, weight, expert_count);
}
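/*
    Illustrative sketch (an assumption, not part of this file): on the Python
    side the bias of an expert can be folded into its weight like

        weight = torch.cat([w, b.unsqueeze(-1)], dim=-1)  # [num_expert, out_feat, in_feat + 1]
        inp = F.pad(inp, pad=(0, 1), value=1.0)           # append the constant 1 to each row

    where `w` and `b` are hypothetical per-expert parameters; padding the input
    with a constant 1 is exactly what the F.pad(..., value=1.0) calls in
    CustomizedMoEPositionwiseFF do later in this commit.
*/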
std::vector<torch::Tensor> moe_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf,       // [batch_size x in_feat]
    torch::Tensor weight,          // [num_expert x out_feat x in_feat]
    torch::Tensor expert_count     // [num_expert]
) {
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    /*
        As in the forward pass, the bias term is assumed to have been merged
        into the weight via W x + b = [W  b] [x; 1].
    */
    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}
@@ -72,6 +100,9 @@ int main() {
*/
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
@@ -5,86 +5,75 @@ import torch
import moe_cuda
torch.manual_seed(42)
torch.cuda.manual_seed(42)
class MOEFunction(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
        return grad_inp, None, grad_weight
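# Note (explanatory sketch, not part of the original source): `pos` maps token
# i to its slot in the expert-sorted buffer, so the custom ops above are
# equivalent to the pure-PyTorch permutations
#
#     idx = pos.long()
#     input_buf = torch.empty_like(inp)
#     input_buf[idx] = inp        # local_scatter: row i is written to slot pos[i]
#     output = output_buf[idx]    # local_gather: row i is read back from slot pos[i]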
class MOELayer(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize each expert the same way nn.Linear initializes its weight.
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        return MOEFunction.apply(inp, gate.int(), self.weight)
class MOELayer_raw(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        # Reference implementation: apply each token's expert weight one by one.
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x
@@ -93,28 +82,24 @@ def test_module(moe, linear, inp, gate):
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    if torch.distributed.get_rank() == 1:
        print(output)
    return output
def test():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
@@ -124,11 +109,36 @@ def test():
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out']
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))
def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    test()
    # test_dp()
@@ -4,11 +4,11 @@
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <mpi.h>
@@ -20,10 +20,6 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_DEBUG
#define MOE_BREAKDOWN
// #define MOE_DEBUG_SCATTER
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
@@ -34,10 +30,9 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
}
}
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
@@ -46,55 +41,15 @@ void batch_scatter_kernel(int wid, int* pos,
}
}
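// Explanatory note (not part of the original source): the scatter/gather
// kernels in this file are launched with one thread block per token row
// (gridDim.x == batch_size); the threads of a block cooperatively copy the
// `wid` contiguous values of that row, striding by blockDim.x.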
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
scalar_t v;
cudaMemcpy(&v, d_ptr, sizeof(scalar_t), cudaMemcpyDeviceToHost);
return v;
}
void moe_cuda_expert_count_impl(
    const int* d_gate,
    int* expert_count,
    int* d_pos,
    const size_t num_expert,
    const size_t batch_size) {
    int *gate = new int[batch_size];
    int *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);

    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
            cudaMemcpyDeviceToHost));
@@ -108,8 +63,6 @@ void moe_cuda_forward_impl(
    }

    int *pos = new int[batch_size];
    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
@@ -120,40 +73,11 @@ void moe_cuda_forward_impl(
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
            cudaMemcpyHostToDevice));

    delete [] gate;
    delete [] expert_ptr;
}
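// Worked example (illustrative, not part of the original source): with
// num_expert = 2 and gate = [1, 0, 1, 1] on the device, the code above yields
//     expert_count = [1, 3]        (tokens per expert)
//     expert_ptr   = [0, 1]        (prefix offsets before the pos loop)
//     pos          = [1, 0, 2, 3]  (slot of each token in the expert-sorted buffer)
// so local_scatter places the single expert-0 token first, followed by the
// three expert-1 tokens.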
int *all_expert_count = new int[tot_expert];
MPI_Alltoall(expert_count, num_expert, MPI_INT,
all_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
int *expert_n = new int[num_expert];
int expert_sz = 0;
for (int i = 0; i < num_expert; ++i) {
expert_n[i] = 0;
for (int j = 0; j < cm->size; ++j) {
expert_n[i] += all_expert_count[j * num_expert + i];
}
expert_sz += expert_n[i];
}
scalar_t *input_buf, *hidden_buf, *output_buf;
if (expert_sz) {
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
}
#ifdef MOE_BREAKDOWN
timestamp(t_expert);
fprintf(stderr, "Expert asn %d time %.3lf us\n",
expert_sz,
getDuration(t_init, t_expert) * 1e6);
#endif
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
local_input_buf);
h->sync(0);
// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));
void moe_cuda_global_scatter() {
if (cm->size > 1) {
if (expert_sz) {
checkCudaErrors(cudaMalloc(&input_buf,
@@ -192,58 +116,137 @@ void moe_cuda_forward_impl(
input_buf = local_input_buf;
output_buf = local_output_buf;
}
}
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
    const scalar_t* input,
    const int* d_pos,
    scalar_t* input_buf,
    const size_t batch_size,
    const size_t in_feat,
    CudaStreamManager* smgr) {
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
            input_buf);
    smgr->sync(1);
}
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
#ifdef MOE_BREAKDOWN
timestamp(t_scatter);
fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
1e6);
#endif
template <typename scalar_t>
void moe_cuda_local_gather_impl(
    const scalar_t* output_buf,
    const int* d_pos,
    scalar_t* output,
    const size_t batch_size,
    const size_t out_feat,
    CudaStreamManager* smgr) {
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
            output);
    smgr->sync(1);
}
template <typename scalar_t>
void moe_cuda_forward_impl(
    const scalar_t* input_buf,
    const scalar_t* weight,
    const int* expert_count,
    scalar_t* output_buf,
    const size_t in_feat,
    const size_t out_feat,
    const size_t num_expert,
    CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(
            smgr->handle(i),
            CUBLAS_OP_T,
            CUBLAS_OP_N,
            out_feat, expert_count[i], in_feat,
            &alpha,
            weight + i * in_feat * out_feat, in_feat,
            input_buf + ptr * in_feat, in_feat,
            &beta,
            output_buf + out_feat * ptr, out_feat
            ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
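// Explanatory note (not part of the original source) on the T(B) x T(A) trick:
// cuBLAS works on column-major matrices while the tensors here are row-major.
// A row-major [n x m] matrix is bit-identical to the column-major [m x n]
// matrix holding its transpose, so asking cuBLAS for
//     C^T = (A B)^T = B^T A^T
// in column-major form leaves C in row-major form. For expert i the call
// above therefore computes, in row-major terms,
//     output_buf[ptr : ptr + expert_count[i], :] =
//         input_buf[ptr : ptr + expert_count[i], :] @ weight[i]^T
// with weight[i] stored as [out_feat x in_feat].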
template <typename scalar_t>
void moe_cuda_backward_impl(
    const scalar_t* grad_output_buf,
    const scalar_t* input_buf,
    const scalar_t* weight,
    const int* expert_count,
    scalar_t* grad_input_buf,
    scalar_t* grad_weight,
    const size_t batch_size,
    const size_t in_feat,
    const size_t out_feat,
    const size_t num_expert,
    CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                sizeof(scalar_t) * in_feat * out_feat);
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        // Backward w.r.t. the input: grad_input = grad_output @ weight
        checkCudaErrors(cublasXgemm(
            smgr->handle(i),
            CUBLAS_OP_N,
            CUBLAS_OP_N,
            in_feat, expert_count[i], out_feat,
            &alpha,
            weight + i * in_feat * out_feat, in_feat,
            grad_output_buf + ptr * out_feat, out_feat,
            &beta,
            grad_input_buf + in_feat * ptr, in_feat
            ));

        // Backward w.r.t. the weight: grad_weight = grad_output^T @ input
        checkCudaErrors(cublasXgemm(
            smgr->handle(i),
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat, out_feat, expert_count[i],
            &alpha,
            input_buf + in_feat * ptr, in_feat,
            grad_output_buf + ptr * out_feat, out_feat,
            &beta,
            grad_weight + i * in_feat * out_feat, in_feat
            ));

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
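// Explanatory note (not part of the original source): since the forward pass
// computes O = X W^T per expert, the chain rule gives
//     dX = dO @ W          (the first GEMM above)
//     dW = dO^T @ X        (the second GEMM above)
// which is what the two cublasXgemm calls evaluate, again using the transpose
// trick to keep everything row-major.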
#ifdef MOE_BREAKDOWN
timestamp(t_mm);
fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
1e6);
#endif
void moe_cuda_global_gather() {
if (cm->size > 1) {
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
@@ -273,162 +276,149 @@ void moe_cuda_forward_impl(
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
}
#ifdef MOE_BREAKDOWN
h->sync(0);
timestamp(t_gather);
fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
1e6);
#endif
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
local_output_buf, output);
h->sync(0);
#ifdef MOE_BREAKDOWN
timestamp(t_end);
fprintf(stderr, "Local gather %.3lf us\n", getDuration(t_gather, t_end) *
1e6);
fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_end) *
1e6);
#endif
if (expert_sz) {
cudaFree(hidden_buf);
if (cm->size > 1) {
cudaFree(input_buf);
cudaFree(output_buf);
}
}
cudaFree(local_input_buf);
cudaFree(local_output_buf);
cudaFree(d_pos);
delete [] pos;
delete [] gate;
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
auto pos_options = torch::TensorOptions()
.device(gate.device())
.dtype(torch::kInt32);
auto pos = torch::empty(batch_size, pos_options);
moe_cuda_expert_count_impl(
gate.data_ptr<int>(),
expert_count.data_ptr<int>(),
pos.data_ptr<int>(),
num_expert,
batch_size);
return {expert_count, pos};
}
template <typename scalar_t>
void moe_cuda_grad_weight(
const scalar_t* input,
const int* gate,
const scalar_t* grad_output,
scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert) {
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = input.size(0);
const auto in_feat = input.size(1);
auto input_buf = torch::empty_like(input);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
([&] {
moe_cuda_local_scatter_impl<scalar_t>(
input.data_ptr<scalar_t>(),
pos.data_ptr<int>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat,
smgr);
}));
return {input_buf,};
}
auto h = getCudaStreamManager(num_expert);
int* gate_host = new int[batch_size];
scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handles[0],
CUBLAS_OP_N,
CUBLAS_OP_T,
out_feat,
in_feat,
1,
&alpha,
grad_output + i * out_feat,
out_feat,
input + i * in_feat,
in_feat,
&beta,
grad_weight + gate_host[i] * out_feat * in_feat,
out_feat));
}
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
}
delete[] gate_host;
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = output_buf.size(0);
const auto out_feat = output_buf.size(1);
auto output = torch::empty_like(output_buf);
AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
([&] {
moe_cuda_local_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
pos.data_ptr<int>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat,
smgr);
}));
return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
    torch::Tensor input_buf,
    torch::Tensor weight,
    torch::Tensor expert_count
) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
        num_expert, in_feat, out_feat);
#endif

    auto out_options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
        ([&] {
            moe_cuda_forward_impl<scalar_t>(
                input_buf.data_ptr<scalar_t>(),
                weight.data_ptr<scalar_t>(),
                expert_count.data_ptr<int>(),
                output.data_ptr<scalar_t>(),
                in_feat,
                out_feat,
                num_expert,
                smgr
            );
        }));

    return {output, };
}
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf,       // [batch_size x in_feat]
    torch::Tensor weight,          // [num_expert x out_feat x in_feat]
    torch::Tensor expert_count     // [num_expert]
) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
        "out_feat (d_ffn)=%ld\n",
        batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {grad_input_buf, grad_weight};
}
@@ -4,48 +4,61 @@ import time
import sys
dev_name = 'cuda:0'
def perf():
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())

    batch_size = int(sys.argv[1])
    in_feat = int(sys.argv[2])
    out_feat = int(sys.argv[3])
    num_expert = int(sys.argv[4])

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = torch.randint(low=0,
            high=num_expert * torch.distributed.get_world_size(),
            size=(batch_size, ), requires_grad=False).int().cuda(dev_name)

    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
    moe.train()

    # Warm up before timing.
    for _ in range(6):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for i in range(n_runs):
        gate = torch.randint(low=0,
                high=num_expert * torch.distributed.get_world_size(),
                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
        ts = time.time()
        o = moe(inp, gate)
        te = time.time()

        loss = o.sum()
        bts = time.time()
        loss.backward()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts)**2
        maxt = max(maxt, te - ts)
        backt += bte - bts

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    std = (sqtot / n_runs - (tott / n_runs)**2) ** .5
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        tott * 1e3 / n_runs, maxt * 1e3, std * 1e3,
        backt * 1e3 / n_runs, gflops))
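# Note (added for clarity): the GFLOPs figure counts 2 * batch_size * in_feat
# * out_feat floating point operations per forward pass (a multiply and an add
# per weight element touched), summed over n_runs and divided by the total
# forward time.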
if __name__ == '__main__':
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
python3 moe.py
elif [ .$1 = '.test_all' ]
then
for nexp in 1 2 4
@@ -20,11 +20,11 @@ then
for bs in 4 16 64 256 512 1024 2048 4096
do
echo $bs $nexp ${inf}x${ouf}
python3 moe_test.py $bs $inf $ouf $nexp
done
done
done
done
else
python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
fi
@@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
# import torch_sparse
from cuda.moe import MOELayer
sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -31,9 +33,74 @@ class PositionalEmbedding(nn.Module):
else:
return pos_emb[:,None,:]
class CustomizedMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=32):
        super(CustomizedMoEPositionwiseFF, self).__init__()
        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
        self.top_k = top_k
        assert num_expert >= top_k
        self.d_model = d_model
        self.d_inner = d_inner

        self.gate = nn.Linear(d_model, num_expert)
        # in_feat is d_model + 1 / d_inner + 1 because the constant-1 padding
        # in forward() folds the bias into the expert weight matrices.
        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model+1, out_feat=d_inner)
        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner+1, out_feat=d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        pass

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)

        core_out = []
        inp = inp.view(-1, self.d_model)
        inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
        for i in range(self.top_k):
            gate_idx = gate_top_k_idx[:, i].contiguous()
            x = self.moe1(inp, gate_idx)
            x = self.dropout(F.relu(x))
            x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
            x = self.moe2(x, gate_idx)
            x = self.dropout(x)  # (BxL) x d_model
            core_out.append(x.unsqueeze(1))  # (BxL) x 1 x d_model

        core_out = torch.cat(core_out, dim=1)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
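# Shape walk-through of the forward pass above (illustrative; assumes the
# [seq_len, batch, d_model] layout used elsewhere in this file):
#   inp after view + pad          -> (L*B) x (d_model + 1)
#   moe1 / relu / pad             -> (L*B) x (d_inner + 1)
#   moe2                          -> (L*B) x d_model, one tensor per top-k expert
#   torch.cat(core_out, dim=1)    -> (L*B) x top_k x d_model
#   torch.bmm(gate_score, ...)    -> (L*B) x 1 x d_model, the gate-weighted mix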
class MoEPositionwiseFFRaw(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
        super(MoEPositionwiseFFRaw, self).__init__()
        print("MoEPositionwiseFF")
        self.top_k = top_k
@@ -820,7 +887,7 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -840,7 +907,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -861,7 +928,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
#!/bin/bash
export LD_LIBRARY_PATH=/home/jiezhong/miniconda3/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \