Commit 069cf01a authored by Rick Ho

make moe run with cuda

parent a4f7f1da
......
@@ -67,6 +67,12 @@ std::vector<torch::Tensor> moe_backward(
 #ifdef MOE_USE_NCCL
+std::vector<torch::Tensor> moe_expert_exchange(
+        torch::Tensor local_expert_count,
+        size_t num_expert, size_t n_workers) {
+    return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
+}
+
 std::vector<torch::Tensor> moe_global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
......
@@ -107,6 +113,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
     m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
 #ifdef MOE_USE_NCCL
+    m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
 #endif
......
@@ -82,6 +82,35 @@ void moe_cuda_expert_count_impl(
 #ifdef MOE_USE_NCCL
+void moe_cuda_expert_exchange_impl(
+        const int* local_expert_count,
+        int* global_expert_count,
+        int* fwd_expert_count,
+        int num_expert, int world_size) {
+    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
+            global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
+    for (int i = 0; i < num_expert; ++i) {
+        for (int j = 0; j < world_size; ++j) {
+            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
+        }
+    }
+}
+
+std::vector<torch::Tensor> moe_cuda_expert_exchange(
+        torch::Tensor local_expert_count,
+        long num_expert, long n_workers) {
+    auto global_expert_count = torch::empty_like(local_expert_count);
+    auto fwe_options = torch::TensorOptions()
+            .dtype(local_expert_count.dtype());
+    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
+    moe_cuda_expert_exchange_impl(
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            fwd_expert_count.data_ptr<int>(),
+            num_expert, n_workers);
+    return {global_expert_count, fwd_expert_count};
+}
+
 template<typename scalar_t>
 void moe_cuda_global_scatter_impl(
         const scalar_t* local_input_buf,
......
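To make the count bookkeeping above concrete, here is a minimal pure-Python sketch (illustrative only, not part of the commit) of what moe_cuda_expert_exchange_impl computes on one rank: the MPI_Alltoall hands every rank the block of each peer's local_expert_count that is addressed to it, and the nested loop then sums those blocks per local expert. The function and variable names in the sketch are hypothetical.

    # Reference semantics of the expert-count exchange, per rank (hypothetical names).
    # all_counts[r] is the flat local_expert_count of rank r, laid out as
    # world_size blocks of num_expert entries; block j holds the tokens that
    # rank r wants to send to the experts living on rank j.
    def expert_exchange_reference(all_counts, rank, num_expert, world_size):
        # MPI_Alltoall with num_expert ints per peer: this rank receives, from each
        # rank j, the block of j's counts that is destined for this rank.
        global_expert_count = []
        for j in range(world_size):
            global_expert_count += all_counts[j][rank * num_expert:(rank + 1) * num_expert]

        # fwd_expert_count[i]: total tokens this rank will push through its local
        # expert i, summed over all sending ranks (the double loop in the CUDA file).
        fwd_expert_count = [0] * num_expert
        for i in range(num_expert):
            for j in range(world_size):
                fwd_expert_count[i] += global_expert_count[i + j * num_expert]
        return global_expert_count, fwd_expert_count

For example, with 2 ranks and 2 experts per rank, all_counts = [[1, 2, 3, 4], [5, 6, 7, 8]] gives rank 0 a global_expert_count of [1, 2, 5, 6] and a fwd_expert_count of [6, 8].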
@@ -41,6 +41,10 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
+
+std::vector<torch::Tensor> moe_cuda_expert_exchange(
+        torch::Tensor local_expert_count,
+        long num_expert, long n_workers);
 #endif

 #endif // MOE_CUDA_KERNEL_H
......
@@ -35,53 +35,56 @@ class MOEGlobal(Function):
         local_expert_count, pos = moe_cuda.expert_count(gate,
                 world_size * num_expert)
-        global_expert_count = torch.empty_like(world_size, num_expert)
-        torch.distributed.all_to_all(global_expert_count,
-                local_expert_count.reshape(world_size, num_expert))
-        batch_size = int(global_expert_count.sum().item())
+        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size)
+        fwd_batch_size = int(fwd_expert_count.sum().item())

         local_input_buf, = moe_cuda.local_scatter(inp, pos)
         global_input_buf, = moe_cuda.global_scatter(local_input_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                fwd_batch_size, world_size)

-        global_output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
+        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
+                fwd_expert_count)

         local_output_buf, = moe_cuda.global_gather(global_output_buf,
                 local_expert_count, global_expert_count,
                 inp.shape[0], world_size)
-        output = moe_cuda.local_gather(local_output_buf, pos)
+        output, = moe_cuda.local_gather(local_output_buf, pos)

-        variables = [input_buf, gate, weight,
-                local_expert_count, global_expert_count,
-                pos, num_expert, batch_size, world_size]
+        variables = (global_input_buf, gate, weight,
+                local_expert_count, global_expert_count, fwd_expert_count,
+                pos)
+        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
         ctx.save_for_backward(*variables)

-        return output[0]
+        return output

     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, gate, weight, local_expert_count, global_expert_count,
-                pos, num_expert, batch_size, world_size) = ctx.saved_tensors
+        (input_buf, gate, weight,
+                local_expert_count, global_expert_count, fwd_expert_count,
+                pos) = ctx.saved_tensors
+        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

         grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
         global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                fwd_batch_size, world_size)

         grad_inp_buf, grad_weight = moe_cuda.backward(
-                global_grad_out_buf, input_buf, weight, expert_count)
+                global_grad_out_buf, input_buf, weight, fwd_expert_count)

-        local_grad_inp_buf = moe_cuda.global_gather(grad_inp_buf,
+        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                local_batch_size, world_size)
         grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

-        return grad_inp, None, grad_weight
+        return grad_inp, None, grad_weight, None


 def moe(inp, gate, weight, world_size):
     if world_size is not None:
-        return MOEGlobal.apply(inp, gate, weight)
+        return MOEGlobal.apply(inp, gate, weight, world_size)
     else:
         return MOELocal.apply(inp, gate, weight)
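The reworked autograd path keeps the scatter/compute/gather pipeline (local_scatter → global_scatter → expert forward → global_gather → local_gather) but changes how state is carried to backward: ctx.save_for_backward only accepts tensors, so the plain Python values (num_expert, the local and forwarded batch sizes, world_size) now travel through an ordinary ctx attribute, and backward returns an extra None to match the new world_size argument of apply. A minimal sketch of that pattern (illustrative only, not the committed code):

    # Tensors go through save_for_backward; non-tensor arguments hang off ctx.
    import torch
    from torch.autograd import Function

    class Scale(Function):
        @staticmethod
        def forward(ctx, inp, factor):
            ctx.save_for_backward(inp)   # tensors only
            ctx.factor = factor          # plain Python value kept as an attribute
            return inp * factor

        @staticmethod
        def backward(ctx, grad_out):
            inp, = ctx.saved_tensors
            # one gradient per forward input; the non-differentiable value gets None
            return grad_out * ctx.factor, None

    x = torch.randn(4, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()   # x.grad is now a tensor of 2.0s

Called this way, the scalar argument never needs to round-trip through saved tensors, which is exactly why moe_args moved onto ctx above.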
......
@@ -82,18 +82,23 @@ def test():
     linear = nn.Linear(in_feat, in_feat).cuda()
-    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    if world_size > 1:
+        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
+    else:
+        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
     moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
     moe_raw.weight.data = moe.weight.data.clone()

     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
-            high=num_expert * torch.distributed.get_world_size(),
-            size=(batch_size, ),
+            high=num_expert * world_size,
+            size=(batch_size,),
             requires_grad=False).int().cuda()
     # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
+    print('hhh')
+    return
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
......
@@ -128,6 +133,9 @@ def test_dp():
 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
     world_size = torch.distributed.get_world_size()
+    if world_size == 1:
+        world_size = None
     test()
     # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
     # perf()
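The test builds the distributed MOELayer and a single-process MOELayer_raw with identical weights, runs both through test_module, and compares the entries named in names. The comparison itself is outside this hunk; a hedged sketch of the kind of check those value lists could feed into (names and tolerance here are assumptions, not the repository's code):

    import torch

    # Illustrative comparison only: moe_out / raw_out are assumed to be lists of
    # tensors returned by test_module for the CUDA MoE layer and the raw reference.
    def compare(names, moe_out, raw_out, atol=1e-5):
        for name, a, b in zip(names, moe_out, raw_out):
            status = 'OK' if torch.allclose(a, b, atol=atol) else 'MISMATCH'
            print('{}: {}'.format(name, status))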
......
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
 export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
 if [ -z $1 ]
 then
-    python3 moe_test.py
+    python3 moe_test.py 2>logs/$OMPI_COMM_WORLD_RANK.log
 elif [ .$1 = '.test_all' ]
 then
     for nexp in 1 2 4
......