Commit a4f7f1da authored by Rick Ho

split function file

parent ec322e4b
moe.py

import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda


class MOEFunction(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        # out_feat, in_feat = weight.size()[1:]
        # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight
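The four moe_cuda calls implement a scatter–compute–gather pattern. The sketch below reproduces it in plain PyTorch for reference only; the sorted-by-expert buffer layout and the exact meaning of pos and expert_count are assumptions inferred from how the calls are used here, not a documented contract of the extension.

    import torch

    # Plain-PyTorch sketch of the scatter/compute/gather pattern above.
    # Assumption: `pos` groups tokens by expert id, and `expert_count[e]`
    # counts the tokens routed to expert `e`.
    def moe_forward_reference(inp, gate, weight):
        # inp: (batch, in_feat); gate: (batch,) int; weight: (num_expert, out_feat, in_feat)
        gate = gate.long()
        num_expert = weight.shape[0]
        pos = torch.argsort(gate)                 # tokens grouped by expert
        expert_count = torch.bincount(gate, minlength=num_expert)
        input_buf = inp[pos]                      # what local_scatter builds
        output_buf = inp.new_empty(inp.shape[0], weight.shape[1])
        base = 0
        for e in range(num_expert):
            n = int(expert_count[e])
            if n > 0:
                # one expert is one linear layer: x @ W_e^T
                output_buf[base:base + n] = input_buf[base:base + n] @ weight[e].t()
            base += n
        output = torch.empty_like(output_buf)
        output[pos] = output_buf                  # what local_gather undoes
        return output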
from moe_function import moe
class MOELayer(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
+                 world_size=None):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()
@@ -49,7 +23,7 @@ class MOELayer(nn.Module):
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate.int(), self.weight)
+        return moe(inp, gate.int(), self.weight, self.world_size)

class MOELayer_raw(nn.Module):
@@ -64,7 +38,8 @@ class MOELayer_raw(nn.Module):
    def reset_parameters(self):
        for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            linear = nn.Linear(in_features=self.in_feat,
+                               out_features=self.out_feat)
            # print(linear.weight.shape)
            self.weight.data[i] = linear.weight.data
@@ -75,73 +50,3 @@ class MOELayer_raw(nn.Module):
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x


def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad

def test():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * torch.distributed.get_world_size(),
                         size=(batch_size, ),
                         requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))

def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)


if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    test()
    # test_dp()
moe_function.py

import torch
from torch.autograd import Function

import moe_cuda


class MOELocal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight

class MOEGlobal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        num_expert = weight.shape[0]
        local_expert_count, pos = moe_cuda.expert_count(
            gate, world_size * num_expert)

        # Exchange per-expert token counts so every worker learns how many
        # tokens it will receive. all_to_all_single splits the flat count
        # tensor into world_size chunks of num_expert entries each; note
        # torch.empty_like takes a tensor, not sizes.
        global_expert_count = torch.empty_like(local_expert_count)
        torch.distributed.all_to_all_single(global_expert_count,
                                            local_expert_count)
        batch_size = int(global_expert_count.sum().item())

        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        global_input_buf, = moe_cuda.global_scatter(local_input_buf,
                local_expert_count, global_expert_count,
                batch_size, world_size)
        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
                global_expert_count)
        local_output_buf, = moe_cuda.global_gather(global_output_buf,
                local_expert_count, global_expert_count,
                inp.shape[0], world_size)
        output = moe_cuda.local_gather(local_output_buf, pos)

        # save_for_backward only accepts tensors; plain ints go on ctx
        ctx.save_for_backward(global_input_buf, gate, weight,
                local_expert_count, global_expert_count, pos)
        ctx.moe_args = (batch_size, inp.shape[0], world_size)
        return output[0]
    @staticmethod
    def backward(ctx, grad_out):
        (global_input_buf, gate, weight, local_expert_count,
            global_expert_count, pos) = ctx.saved_tensors
        batch_size, local_batch_size, world_size = ctx.moe_args

        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                batch_size, world_size)
        grad_inp_buf, grad_weight = moe_cuda.backward(
                global_grad_out_buf, global_input_buf, weight,
                global_expert_count)
        # gather back to the local batch, mirroring the forward pass
        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

        # forward took four arguments, so four gradients are returned
        return grad_inp, None, grad_weight, None
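The count exchange is what sizes the cross-worker buffers. Below is a hypothetical trace with made-up numbers, assuming 2 workers with 2 experts each and the (world_size, num_expert) count layout inferred above; the layout is an assumption about moe_cuda, not documented behavior.

    # Hypothetical trace: world_size=2, num_expert=2. Row w of the flat
    # local_expert_count holds how many local tokens go to worker w's experts.
    #
    #   worker 0: local_expert_count  = [[3, 1],   # to worker 0's experts
    #                                    [2, 0]]   # to worker 1's experts
    #   worker 1: local_expert_count  = [[0, 4],
    #                                    [1, 1]]
    #
    # all_to_all_single delivers chunk w of every worker to worker w, so row w
    # of global_expert_count says how many tokens arrive from worker w:
    #
    #   worker 0: global_expert_count = [[3, 1],   # from itself
    #                                    [0, 4]]   # from worker 1
    #   worker 1: global_expert_count = [[2, 0],
    #                                    [1, 1]]
    #
    # batch_size = global_expert_count.sum() (8 on worker 0, 4 on worker 1)
    # sizes the receive buffer for global_scatter; global_gather replays the
    # same counts in reverse to return expert outputs to their owners.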

def moe(inp, gate, weight, world_size):
    if world_size is not None:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    else:
        return MOELocal.apply(inp, gate, weight)
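For reference, this dispatcher is called from MOELayer.forward. A usage sketch, assuming inp, gate, and weight shaped as in MOELayer:

    # Local (single-process) path: world_size=None dispatches to MOELocal.
    output = moe(inp, gate.int(), weight, None)

    # Distributed path: with torch.distributed initialized (e.g. backend='mpi'),
    # pass the world size so MOEGlobal exchanges tokens between workers first.
    output = moe(inp, gate.int(), weight, torch.distributed.get_world_size())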
moe_test.py

from moe import MOELayer, MOELayer_raw
import torch
from torch import nn
import time
import sys

@@ -61,7 +62,72 @@ def perf():
        backt * 1e3 / n_runs, gflops))

def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad

def test():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * torch.distributed.get_world_size(),
                         size=(batch_size, ),
                         requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))

def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)

if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    test()
    # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
-    perf()
+    # perf()
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH

if [ -z $1 ]
then
-    python3 moe.py
+    python3 moe_test.py
elif [ .$1 = '.test_all' ]
then
    for nexp in 1 2 4