Commit 22e1eb45 authored by Rick Ho

complete test for reconstruction

parent d2039fc7
-from .moe import BruteForceMoE
 from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
@@ -3,8 +3,11 @@ from torch.autograd import Function
 import fmoe_cuda
-def moe_prepare_forward(gate, num_expert, world_size):
-    fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, gate)
+def moe_prepare_forward(gate, num_expert, world_size, comm=None):
+    if comm is None:
+        comm = torch.distributed.distributed_c10d._default_pg
+    if world_size > 1:
+        fmoe_cuda.ensure_nccl(comm, gate)
     with torch.no_grad():
         _, pos = torch.sort(gate)
...
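Aside from the optional communicator, moe_prepare_forward mainly builds routing metadata: it sorts the gate indices and counts how many tokens go to each expert. A rough pure-PyTorch sketch of that bookkeeping is below; the function and variable names are made up for illustration and do not appear in the repository.

# Sketch of the counting step, assuming gate holds one expert index per token.
import torch

def count_by_gate_sketch(gate: torch.Tensor, num_expert: int, world_size: int):
    # pos[i] is the original index of the token placed in slot i after
    # grouping tokens by their target expert.
    _, pos = torch.sort(gate)
    # Per-expert token counts over all num_expert * world_size experts.
    expert_count = torch.zeros(num_expert * world_size, dtype=torch.long,
                               device=gate.device)
    idx, cnt = torch.unique(gate, return_counts=True)
    expert_count.index_put_((idx.long(),), cnt)
    return pos, expert_count

# Example: 6 tokens routed among 4 experts on a single worker.
gate = torch.tensor([2, 0, 2, 3, 0, 2])
pos, cnt = count_by_gate_sketch(gate, num_expert=4, world_size=1)
# cnt == tensor([2, 0, 3, 1]); pos groups expert 0's tokens first,
# then expert 2's, then expert 3's.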
@@ -57,7 +57,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
 class FMoETransformerMLP(nn.Module):
     def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=None, activation=torch.nn.functional.gelu,
+            world_size=1, activation=torch.nn.functional.gelu,
             top_k=2, pre_lnorm=False):
         super(FMoETransformerMLP, self).__init__()
         self.num_expert = num_expert
...
import torch
from torch.autograd import Function

import fmoe_cuda


class MOELocal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        expert_count = torch.zeros(weight.shape[0], device=weight.device,
                dtype=torch.long)
        expert_count.index_put_((gate_idx.long(), ), gate_count)
        # expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
        ecc = expert_count.cpu()

        input_buf, = fmoe_cuda.local_gather(inp, pos)
        output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
        output = fmoe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, ecc, pos]
        ctx.save_for_backward(*variables)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)
        return grad_inp, None, grad_weight
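For the single-process case, the result MOELocal is expected to produce can be described with a token-by-token loop: each token is multiplied by the weight of the expert its gate entry points to. The sketch below is a plain-PyTorch reference in the spirit of the brute-force layer used by the tests, not the CUDA path; the function name is made up for illustration.

# Dense reference: out[i] = weight[gate[i]] @ inp[i] for every token i.
import torch

def moe_local_reference(inp, gate, weight):
    # inp: (batch, in_feat), gate: (batch,), weight: (num_expert, out_feat, in_feat)
    out = inp.new_empty(inp.shape[0], weight.shape[1])
    for i in range(inp.shape[0]):
        out[i] = weight[gate[i]] @ inp[i]
    return out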
class MOEGlobal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        fmoe_cuda.ensure_nccl(
                torch.distributed.distributed_c10d._default_pg, inp)
        num_expert = weight.shape[0]

        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
        #         world_size * num_expert)
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(weight.shape[0] * world_size,
                device=weight.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(), ), gate_count)

        global_expert_count, = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0).cpu()
        fwd_batch_size = int(fwd_expert_count.sum().item())

        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
        local_expert_count = local_expert_count.cpu()
        global_expert_count = global_expert_count.cpu()

        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                local_input_buf, weight,
                local_expert_count, global_expert_count,
                fwd_batch_size, inp.shape[0], world_size)
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)

        variables = (global_input_buf, gate, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos)
        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, gate, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos) = ctx.saved_tensors
        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                global_grad_out_buf, input_buf, weight, fwd_expert_count)
        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)
        return grad_inp, None, grad_weight, None
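Before any payload moves, MOEGlobal exchanges the per-expert token counts across workers so that each one knows how large its incoming buffer (fwd_batch_size) will be. Assuming fmoe_cuda.expert_exchange behaves like an all-to-all over the count tensor, a plain torch.distributed sketch of that step could look like the following; the function name is invented for illustration and this is not the library's implementation.

# Each worker tells every other worker how many tokens it will send to each of
# that worker's experts. Requires an initialized process group.
import torch
import torch.distributed as dist

def exchange_expert_counts_sketch(local_expert_count, num_expert, world_size):
    # local_expert_count: (world_size * num_expert,); entry [w * num_expert + e]
    # is the number of local tokens routed to expert e on worker w.
    global_expert_count = torch.empty_like(local_expert_count)
    # all_to_all_single splits both tensors into world_size equal chunks and
    # swaps chunk w with worker w.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count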
def moe(inp, gate, weight, world_size):
    if world_size is not None and world_size > 1:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    else:
        return MOELocal.apply(inp, gate, weight)
@@ -3,28 +3,27 @@ from torch import nn
 import torch
 import torch.nn.functional as F
-from .moe_function import moe
+from fmoe.layers import FMoELinear, _fmoe_full_forward

 class FMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
-            world_size=None):
+            world_size=1):
         super(FMoE, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.world_size = world_size
-        self.weight = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, in_feat))
+        self.linear = FMoELinear(num_expert, in_feat, out_feat)
+        self.weight = self.linear.weight
         self.reset_parameters()

     def reset_parameters(self):
-        for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+        self.linear.reset_parameters()

     def forward(self, inp, gate):
-        return moe(inp, gate.int(), self.weight, self.world_size)
+        return _fmoe_full_forward(inp, gate, [self.linear], None,
+                self.num_expert, self.world_size)

 class BruteForceMoE(nn.Module):
...
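With this change the test's FMoE wraps the reconstructed FMoELinear / _fmoe_full_forward code path while keeping a .weight alias, so it can still be compared weight-for-weight against the brute-force layer. A minimal single-GPU comparison in the spirit of the test might look like the sketch below; it assumes a working fmoe_cuda build, that it is run from the tests directory, and the tensor sizes, tolerance, and int gate dtype are guesses rather than values from the repository.

# Compare the reconstructed layer against the brute-force reference with
# shared expert weights on one process.
import torch
from moe import FMoE, BruteForceMoE

num_expert, feat, batch = 4, 16, 32
moe = FMoE(num_expert, feat, feat, world_size=1).cuda()
moe_raw = BruteForceMoE(num_expert, feat, feat, world_size=1).cuda()
moe_raw.weight.data = moe.weight.data.clone()

inp = torch.randn(batch, feat, device='cuda')
gate = torch.randint(0, num_expert, (batch,), device='cuda').int()

out = moe(inp, gate)
out_raw = moe_raw(inp, gate)
# Outputs should agree up to floating-point error.
assert torch.allclose(out, out_raw, atol=1e-5, rtol=1e-5)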
-from fmoe import FMoE as MOELayer
-from fmoe import BruteForceMoE as MOELayer_raw
+from moe import FMoE as MOELayer
+from moe import BruteForceMoE as MOELayer_raw
 import torch
 from torch import nn
 import time
@@ -82,7 +82,6 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
-    # print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -102,10 +101,7 @@ def test():
     linear = nn.Linear(in_feat, in_feat).cuda()
-    if world_size > 1:
-        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
-    else:
-        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
     moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()
...