Commit 22e1eb45 authored by Rick Ho

complete test for reconstruction

parent d2039fc7
-from .moe import BruteForceMoE
+from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
@@ -3,8 +3,11 @@ from torch.autograd import Function
 import fmoe_cuda

-def moe_prepare_forward(gate, num_expert, world_size):
-    fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, gate)
+def moe_prepare_forward(gate, num_expert, world_size, comm=None):
+    if comm is None:
+        comm = torch.distributed.distributed_c10d._default_pg
+    if world_size > 1:
+        fmoe_cuda.ensure_nccl(comm, gate)
     with torch.no_grad():
         _, pos = torch.sort(gate)
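The new optional comm argument lets a caller pass an explicit process group instead of the default one, and ensure_nccl is now skipped entirely when world_size is 1. A minimal usage sketch under those assumptions (the two-rank sub-group, the tensor shape, and the import path are illustrative; the diff does not name the file):

import torch
import torch.distributed as dist
from fmoe.functions import moe_prepare_forward  # import path assumed

# Illustrative only: run the MoE exchange over a two-rank sub-group
# rather than the default process group.
dist.init_process_group(backend="nccl")
sub_group = dist.new_group(ranks=[0, 1])

gate = torch.randint(0, 4, (8,), device="cuda")  # one expert id per token
# Return value elided here because the hunk above is truncated.
_ = moe_prepare_forward(gate, num_expert=4, world_size=2, comm=sub_group)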
@@ -57,7 +57,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):

 class FMoETransformerMLP(nn.Module):
     def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=None, activation=torch.nn.functional.gelu,
+            world_size=1, activation=torch.nn.functional.gelu,
             top_k=2, pre_lnorm=False):
         super(FMoETransformerMLP, self).__init__()
         self.num_expert = num_expert
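The default world_size changes from None to 1 here and, below, in FMoE and the test driver, so single-process construction no longer needs a special case. A hypothetical instantiation under the new convention:

from fmoe.layers import FMoETransformerMLP  # module used elsewhere in this diff

# world_size defaults to 1, i.e. no inter-GPU exchange; pass the actual
# process-group size explicitly when running distributed.
mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=128)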
import torch
from torch.autograd import Function

import fmoe_cuda


class MOELocal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        expert_count = torch.zeros(weight.shape[0], device=weight.device,
                dtype=torch.long)
        expert_count.index_put_((gate_idx.long(), ), gate_count)
        # expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
        ecc = expert_count.cpu()

        input_buf, = fmoe_cuda.local_gather(inp, pos)
        output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
        output = fmoe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, ecc, pos]
        ctx.save_for_backward(*variables)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)
        return grad_inp, None, grad_weight


class MOEGlobal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        fmoe_cuda.ensure_nccl(
            torch.distributed.distributed_c10d._default_pg, inp)

        num_expert = weight.shape[0]
        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
        #         world_size * num_expert)
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(weight.shape[0] * world_size,
                device=weight.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(), ), gate_count)

        global_expert_count, = fmoe_cuda.expert_exchange(
            local_expert_count, num_expert, world_size)
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0).cpu()
        fwd_batch_size = int(fwd_expert_count.sum().item())

        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
        local_expert_count = local_expert_count.cpu()
        global_expert_count = global_expert_count.cpu()

        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
            local_input_buf, weight,
            local_expert_count, global_expert_count,
            fwd_batch_size, inp.shape[0], world_size)
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)

        variables = (global_input_buf, gate, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos)
        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, gate, weight,
         local_expert_count, global_expert_count, fwd_expert_count,
         pos) = ctx.saved_tensors
        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)

        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            global_grad_out_buf, input_buf, weight, fwd_expert_count)

        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)
        return grad_inp, None, grad_weight, None


def moe(inp, gate, weight, world_size):
    if world_size is not None and world_size > 1:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    else:
        return MOELocal.apply(inp, gate, weight)
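Both autograd functions begin with the same counting and permutation step before handing off to the fmoe_cuda kernels: sort the gate so tokens bound for the same expert become contiguous, then count tokens per expert (including idle experts). A pure-PyTorch sketch of that step with illustrative values (the helper name prepare_local is made up for this example):

import torch

def prepare_local(gate, num_expert):
    # Sort token indices so rows destined for the same expert are adjacent.
    _, pos = torch.sort(gate)
    # Count how many tokens each expert receives; idle experts stay at zero.
    idx, cnt = torch.unique(gate, return_counts=True)
    expert_count = torch.zeros(num_expert, dtype=torch.long, device=gate.device)
    expert_count.index_put_((idx.long(),), cnt)
    return pos, expert_count

gate = torch.tensor([2, 0, 2, 1])
pos, count = prepare_local(gate, num_expert=4)
# pos == [1, 3, 0, 2] reorders tokens by expert; count == [1, 1, 2, 0]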
@@ -3,28 +3,27 @@ from torch import nn
 import torch
 import torch.nn.functional as F

-from .moe_function import moe
+from fmoe.layers import FMoELinear, _fmoe_full_forward


 class FMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
-            world_size=None):
+            world_size=1):
         super(FMoE, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.world_size = world_size
-        self.weight = nn.Parameter(
-            torch.Tensor(num_expert, out_feat, in_feat))
+        self.linear = FMoELinear(num_expert, in_feat, out_feat)
+        self.weight = self.linear.weight
         self.reset_parameters()

     def reset_parameters(self):
-        for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+        self.linear.reset_parameters()

     def forward(self, inp, gate):
-        return moe(inp, gate.int(), self.weight, self.world_size)
+        return _fmoe_full_forward(inp, gate, [self.linear], None,
+                self.num_expert, self.world_size)


 class BruteForceMoE(nn.Module):
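After this refactor, FMoE keeps its parameters in an FMoELinear (exposing .weight as an alias for compatibility) and delegates the forward pass to _fmoe_full_forward. A minimal single-process sketch, assuming a built fmoe CUDA extension and the tests-local moe module from this diff:

import torch
from moe import FMoE  # the tests-local module shown above

layer = FMoE(num_expert=4, in_feat=16, out_feat=16, world_size=1).cuda()
inp = torch.randn(8, 16, device="cuda")
gate = torch.randint(0, 4, (8,), device="cuda")  # expert id per token
out = layer(inp, gate)  # routed through _fmoe_full_forward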
-from fmoe import FMoE as MOELayer
-from fmoe import BruteForceMoE as MOELayer_raw
+from moe import FMoE as MOELayer
+from moe import BruteForceMoE as MOELayer_raw
 import torch
 from torch import nn
 import time
@@ -82,7 +82,6 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
-    # print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -102,10 +101,7 @@ def test():
     linear = nn.Linear(in_feat, in_feat).cuda()
-    if world_size > 1:
-        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
-    else:
-        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
     moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()
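Since world_size now defaults to 1, the old if/else around the constructor collapses into one call, and the brute-force layer gets a copy of the fused layer's weights so the two can be compared on identical parameters. A sketch of the comparison this test presumably performs (the tolerance and exact assertions are assumptions):

# Run the fused layer and the brute-force reference on the same inputs;
# both were initialized with identical weights above (world_size == 1).
out_fused, grad_fused, _, _ = test_module(moe, linear, inp, gate)
out_raw, grad_raw, _, _ = test_module(moe_raw, linear, inp, gate)
assert torch.allclose(out_fused, out_raw, atol=1e-5)
assert torch.allclose(grad_fused, grad_raw, atol=1e-5)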