Commit 307e0ad9 authored by Rick Ho

Merge branch 'laekov/batching' into laekov/multigpu

parents cbd86de8 861b75c1
#include <cuda_runtime.h>
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>

#include "cuda_stream_manager.h"
#include <helper_cuda.h>

#define SMGR_N_STREAMS 4

cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}

cublasHandle_t CudaStreamManager::handle(size_t idx) {
    return this->handles[idx % SMGR_N_STREAMS];
}

void CudaStreamManager::sync(int idx) {
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}

void CudaStreamManager::setup(const int device) {
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamCreate(streams + i));
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
}

void CudaStreamManager::destroy() {
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamDestroy(streams[i]));
        checkCudaErrors(cublasDestroy(handles[i]));
    }
    delete[] streams;
    delete[] handles;
}

std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;

CudaStreamManager* getCudaStreamManager(const int device) {
    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        smgr_mtx.lock();
        it = smgrs.find(device);
        if (it == smgrs.end()) {
            auto smgr = new CudaStreamManager(device);
            smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
            smgr_mtx.unlock();
            return smgr;
        } else {
            smgr_mtx.unlock();
        }
    }
    return it->second;
}
@@ -3,44 +3,30 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>

class CudaStreamManager {
public:
    int device;
    cublasHandle_t* handles;
    cudaStream_t* streams;

public:
    CudaStreamManager(int device_): device(device_) {
        this->setup(device);
    }

    void setup(int);
    void sync(int=0);
    void destroy();

    cudaStream_t stream(size_t=0);
    cublasHandle_t handle(size_t=0);

    ~CudaStreamManager() {
        this->destroy();
    }
};

CudaStreamManager* getCudaStreamManager(const int device);

#endif // CUDA_STREAM_MANAGER
@@ -4,17 +4,27 @@
#include <iostream>
#include <vector>

std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate, size_t num_expert);

std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

// C++ interface

@@ -23,40 +33,58 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    CHECK_INPUT(gate);
    return moe_cuda_expert_count(gate, num_expert);
}

std::vector<torch::Tensor> moe_local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    CHECK_INPUT(input);
    return moe_cuda_local_scatter(input, pos);
}

std::vector<torch::Tensor> moe_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    CHECK_INPUT(output_buf);
    return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
        torch::Tensor input_buf,    // [batch_size x in_feat]
        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count  // [batch_size]
        ) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    /*
        The bias term should have been merged into weight. Note the following fact that
        Wx+b = [W b] [x]
                     [1]
    */
    return moe_cuda_forward(input_buf, weight, expert_count);
}
std::vector<torch::Tensor> moe_backward(
        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
        torch::Tensor input_buf,        // [batch_size x out_feat]
        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count
        ) {
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    /*
        The bias term should have been merged into weight. Note the following fact that
        Wx+b = [W b] [x]
                     [1]
    */
    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}
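The comment in both wrappers relies on the standard bias-folding identity. A minimal PyTorch sketch of that identity, purely illustrative and with made-up sizes:

import torch

# Wx + b == [W b] @ [x; 1]: a layer without an explicit bias can emulate one
# by appending a constant 1 to the input features.
out_feat, in_feat = 7, 6
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(in_feat)

W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)   # [W b], shape [out_feat, in_feat + 1]
x_aug = torch.cat([x, torch.ones(1)])           # [x; 1], shape [in_feat + 1]

assert torch.allclose(W @ x + b, W_aug @ x_aug, atol=1e-6)

The Python-side feed-forward further down uses the same trick when it pads activations with a constant 1 and builds its experts with in_feat = d_model + 1 and d_inner + 1.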
@@ -72,6 +100,9 @@ int main() {
*/

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
    m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
    m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
    m.def("forward", &moe_forward, "MoE forward (CUDA)");
    m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
@@ -5,86 +5,75 @@ import torch
import moe_cuda

class MOEFunction(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        # out_feat, in_feat = weight.size()[1:]
        # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight
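For readers of the diff, this is a rough pure-PyTorch sketch of what the expert_count / local_scatter calls above are assumed to do: count tokens per expert, then permute rows so that tokens routed to the same expert sit contiguously in the scattered buffer. It is illustrative only, not the CUDA implementation:

import torch

def expert_count_reference(gate, num_expert):
    # Per-expert token counts plus a destination row for every input row.
    expert_count = torch.bincount(gate.long(), minlength=num_expert)
    base = torch.cumsum(expert_count, 0) - expert_count   # start offset of each expert
    pos = torch.empty_like(gate, dtype=torch.long)
    fill = base.clone()
    for i, e in enumerate(gate.long().tolist()):
        pos[i] = fill[e]
        fill[e] += 1
    return expert_count, pos

def local_scatter_reference(inp, pos):
    # Row i of the input is copied to row pos[i] of the buffer;
    # local_gather is the inverse permutation.
    buf = torch.empty_like(inp)
    buf[pos] = inp
    return buf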
class MOELayer(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        return MOEFunction.apply(inp, gate.int(), self.weight)
class MOELayer_raw(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            # print(linear.weight.shape)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x
@@ -93,28 +82,24 @@ def test_module(moe, linear, inp, gate):
    moe.zero_grad()

    x = (linear(inp))
    output = moe(x, gate)
    y = output.mean()
    y.backward()

    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
def test():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
@@ -124,11 +109,36 @@ def test():
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    names = ['Out']
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))
def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    test()
    # test_dp()
@@ -4,11 +4,11 @@
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include <mpi.h>
@@ -20,10 +20,6 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
@@ -34,10 +30,9 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
    }
}

template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
@@ -46,55 +41,15 @@ void batch_scatter_kernel(int wid, int* pos,
    }
}

void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
    int *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);
    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));
@@ -108,8 +63,6 @@ void moe_cuda_forward_impl(
    }
    int *pos = new int[batch_size];

    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
@@ -120,40 +73,11 @@
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
    delete [] gate;
    delete [] expert_ptr;
}

void moe_cuda_global_scatter() {
    if (cm->size > 1) {
        if (expert_sz) {
            checkCudaErrors(cudaMalloc(&input_buf,
@@ -192,58 +116,137 @@ void moe_cuda_forward_impl(
        input_buf = local_input_buf;
        output_buf = local_output_buf;
    }
}

template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const int* d_pos,
        scalar_t* input_buf,
        const size_t batch_size,
        const size_t in_feat,
        CudaStreamManager* smgr) {
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                input_buf);
    smgr->sync(1);
}

template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const int* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
                output);
    smgr->sync(1);
}

template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                input_buf + ptr * in_feat, in_feat,
                &beta,
                output_buf + out_feat * ptr, out_feat
                ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
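The "T(B) x T(A) = T(C)" comment is the usual trick for getting a row-major result out of a column-major BLAS such as cuBLAS. A small NumPy sketch of the identity, illustrative only and independent of the call above:

import numpy as np

# A row-major buffer read by a column-major BLAS looks like the transpose.
# Asking the BLAS for T(B) x T(A) therefore makes it store T(C) in
# column-major order, and those bytes are exactly the row-major C = A x B.
A = np.random.rand(4, 3)      # row-major operand, e.g. an input tile
B = np.random.rand(3, 5)      # row-major operand, e.g. a transposed weight
C = A @ B                     # desired row-major result

C_T = B.T @ A.T               # what the column-major GEMM actually computes
assert np.allclose(C_T.T, C)  # T(C) in column-major == C in row-major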
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat);
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        // Backward input: g_i = w @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));
        // Backward weight: g_w = i @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
void moe_cuda_global_gather() {
    if (cm->size > 1) {
        int send_ptr = 0;
        for (int i = 0; i < num_expert; ++i) {
@@ -273,162 +276,149 @@ void moe_cuda_forward_impl(
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }
}
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    const auto batch_size = gate.size(0);

    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);

    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);

    moe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);

    return {expert_count, pos};
}
std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = input.size(0);
    const auto in_feat = input.size(1);

    auto input_buf = torch::empty_like(input);

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
            ([&] {
        moe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<int>(),
            input_buf.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            smgr);
    }));
    return {input_buf,};
}

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = output_buf.size(0);
    const auto out_feat = output_buf.size(1);

    auto output = torch::empty_like(output_buf);

    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
            ([&] {
        moe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<int>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            smgr);
    }));
    return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count
        ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
#endif

    auto out_options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<int>(),
            output.data_ptr<scalar_t>(),
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {output, };
}
std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
        torch::Tensor input_buf,        // [batch_size x out_feat]
        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count
        ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {grad_input_buf, grad_weight};
}
...
@@ -4,48 +4,61 @@ import time
import sys

dev_name = 'cuda:0'

def perf():
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())

    batch_size = int(sys.argv[1])
    in_feat = int(sys.argv[2])
    out_feat = int(sys.argv[3])
    num_expert = int(sys.argv[4])

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = torch.randint(low=0,
            high=num_expert * torch.distributed.get_world_size(),
            size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
    moe.train()

    # warm-up runs before timing
    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for i in range(n_runs):
        gate = torch.randint(low=0,
                high=num_expert * torch.distributed.get_world_size(),
                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
        ts = time.time()
        o = moe(inp, gate)
        te = time.time()

        loss = o.sum()
        bts = time.time()
        loss.backward()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts)**2
        maxt = max(maxt, te - ts)
        backt += bte - bts

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        tott * 1e3 / n_runs, maxt * 1e3,
        (sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs,
        backt * 1e3 / n_runs, gflops))
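For reference, the GFLOPs figure above counts two floating-point operations (a multiply and an add) per weight element per token for a single in_feat x out_feat GEMM, accumulated over the timed runs. A sketch of the accounting with example numbers:

# Illustrative numbers only; the script takes these from sys.argv.
batch_size, in_feat, out_feat, n_runs = 4096, 1024, 4096, 16
flops_per_run = 2 * batch_size * in_feat * out_feat
total_gflops = 1e-9 * n_runs * flops_per_run
# dividing by the accumulated forward time tott (seconds) gives the reported rate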
if __name__ == '__main__':
...
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
	python3 moe.py
elif [ .$1 = '.test_all' ]
then
	for nexp in 1 2 4
@@ -20,11 +20,11 @@ then
			for bs in 4 16 64 256 512 1024 2048 4096
			do
				echo $bs $nexp ${inf}x${ouf}
				python3 moe_test.py $bs $inf $ouf $nexp
			done
			done
		done
	done
else
	python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
fi
@@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
# import torch_sparse

from cuda.moe import MOELayer

sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -31,9 +33,74 @@ class PositionalEmbedding(nn.Module):
        else:
            return pos_emb[:,None,:]

class CustomizedMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=32):
        super(CustomizedMoEPositionwiseFF, self).__init__()
        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
        self.top_k = top_k
        assert num_expert >= top_k
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.gate = nn.Linear(d_model, num_expert)
        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model+1, out_feat=d_inner)
        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner+1, out_feat=d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        pass

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)

        core_out = []
        inp = inp.view(-1, self.d_model)
        inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
        for i in range(self.top_k):
            gate_idx = gate_top_k_idx[:, i].contiguous()
            x = self.moe1(inp, gate_idx)
            x = self.dropout(F.relu(x))
            x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
            x = self.moe2(x, gate_idx)
            x = self.dropout(x)              # (BxL) x d_model
            core_out.append(x.unsqueeze(1))  # (BxL) x 1 x d_model

        core_out = torch.cat(core_out, dim=1)        # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)   # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
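A minimal usage sketch of the block above, assuming transformer-xl's [seq_len, batch, d_model] input convention; it needs the compiled moe_cuda extension and a GPU, and the sizes here are arbitrary:

import torch

# Hypothetical sizes; the real model passes its own d_model/d_inner/dropout.
ff = CustomizedMoEPositionwiseFF(d_model=16, d_inner=64, dropout=0.1,
                                 pre_lnorm=False, top_k=2, num_expert=4).cuda()
x = torch.rand(10, 2, 16).cuda()   # [seq_len, batch, d_model]
y = ff(x)                          # residual added, shape preserved
assert y.shape == x.shape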
class MoEPositionwiseFFRaw(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
        super(MoEPositionwiseFFRaw, self).__init__()
        print("MoEPositionwiseFF")
        self.top_k = top_k
@@ -820,7 +887,7 @@ class DecoderLayer(nn.Module):
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
            pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -840,7 +907,7 @@ class RelLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
            **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
            pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -861,7 +928,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
            d_head, dropout, **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
            pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...
#!/bin/bash
export LD_LIBRARY_PATH=/home/jiezhong/miniconda3/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
...