Commit fc78d5c3 authored by Sengxian

Add test for FMoE

parent 103343ca
moe.py (before this commit):

import math
from torch import nn
import torch
import torch.nn.functional as F
from fmoe.layers import FMoELinear, _fmoe_full_forward


class FMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=1):
        super(FMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        self.linear = FMoELinear(num_expert, in_feat, out_feat)
        self.weight = self.linear.weight
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, inp, gate):
        return _fmoe_full_forward(inp, gate, [self.linear], None,
                                  self.num_expert, self.world_size)


class BruteForceMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=0):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat,
                               out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x

moe.py (after this commit):

import math
from torch import nn
import torch


class BruteForceMoELinear(nn.Module):
    # Reference implementation of an MoE feed-forward layer: applies each
    # token's selected expert (a two-layer MLP with an activation between
    # the layers) one token at a time, then mixes the top-k expert outputs
    # with the gate scores.
    def __init__(self, activation, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        # The 4x hidden size here is only a default shape; the test driver
        # overwrites .data with the actual expert weights before running.
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model * 4, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model, d_model * 4)
        )
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        # inp holds top_k copies of every token, one row per
        # (token, selected expert) pair.
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            t = inp[i] @ self.weight_htoh4[gate_long[i]].t()
            t = self.activation(t)
            x[i] = t @ self.weight_h4toh[gate_long[i]].t()
        # Weighted sum over the top_k axis:
        # (batch, 1, top_k) @ (batch, top_k, d_model) -> (batch, 1, d_model).
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x
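As a quick interface check (not part of the commit), the new reference module can be driven standalone. The sketch below assumes moe.py above is importable, uses toy sizes on CPU, initializes the otherwise-uninitialized parameters with Xavier, and feeds the flat gate_idx and (batch, 1, top_k) gate_score shapes that the test driver further down relies on:

import torch
from torch import nn
from moe import BruteForceMoELinear

batch, top_k, num_expert, d_model = 4, 2, 3, 6
ref = BruteForceMoELinear(
    activation=torch.nn.functional.gelu,
    num_expert=num_expert,
    d_model=d_model,
    world_size=1,
    top_k=top_k,
)
# The parameters are allocated with torch.Tensor and never initialized,
# so fill them before using the module on its own.
nn.init.xavier_uniform_(ref.weight_htoh4)
nn.init.xavier_uniform_(ref.weight_h4toh)

# One flat expert index per (token, k) slot, and mixing scores shaped
# (batch, 1, top_k) so that torch.bmm reduces over the top_k axis.
gate_idx = torch.randint(0, num_expert, (batch * top_k,))
gate_score = torch.softmax(torch.rand(batch, 1, top_k), dim=-1)

# Each token appears top_k times, once per selected expert, mirroring the
# repeat_interleave call in the test driver.
inp = torch.rand(batch, d_model).repeat_interleave(top_k, dim=0)
out = ref(inp, gate_idx, gate_score)
assert out.shape == (batch, d_model)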
The test driver (file path not shown in this extract), before this commit:

from moe import FMoE as MOELayer
from moe import BruteForceMoE as MOELayer_raw
import torch
from torch import nn
import sys
import os

rank = None
world_size = None


def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = (linear(inp))
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


def test_moe():
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7
    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        weight_array = [torch.empty_like(moe.weight.data)
                        for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data)
        moe_raw.weight.data = torch.cat(weight_array, dim=0)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * world_size,
                         size=(batch_size,),
                         requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    if world_size > 1:
        ou, wg, lwg, lbg = raw_out
        torch.distributed.all_reduce(wg)
        wg = wg[rank * num_expert:(rank + 1) * num_expert]
        raw_out = ou, wg, lwg, lbg
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('Rank {} {} abs err {}'.format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write('=========== moe out ==============\n')
            sys.stderr.write('{}\n'.format(mo))
            sys.stderr.write('=========== raw out ==============\n')
            sys.stderr.write('{}\n'.format(ro))
            return


if __name__ == '__main__':
    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
    if int(os.environ['WORLD_SIZE']) > 1:
        torch.distributed.init_process_group(backend='nccl')
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1
    test_moe()

The test driver, after this commit:

from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from fmoe.gates import NaiveGate
from moe import BruteForceMoELinear
import torch
import sys
import os

rank = 0
world_size = 1


def test_fmoe_linear():
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    batch_size = 4
    num_expert = 2
    d_model = 6
    d_hidden = 8
    top_k = 2
    activation = torch.nn.functional.gelu
    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()
    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
    ).cuda()

    if world_size == 1:
        # Single process: the reference reuses the experts' weights directly.
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
    else:
        # Multi-process: each rank owns num_expert experts, so gather every
        # rank's weights and let the reference model hold all experts.
        weight_htoh4_array = [
            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)

    inp = torch.rand(batch_size, d_model).cuda()

    # Reuse the FMoE layer's own gate so both models route identically.
    gate_idx, gate_score = moe.gate(inp)
    inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
    moe_out = moe(inp).mean()
    raw_out = moe_raw(inp_repeated, gate_idx, gate_score).mean()
    moe_out.backward()
    raw_out.backward()

    moe_out = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    if world_size > 1:
        # The reference only saw this rank's tokens, so sum its gradients
        # across ranks and compare the slice owned by this rank's experts.
        ou, htoh4_grad, h4toh_grad = raw_out
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
        raw_out = ou, htoh4_grad, h4toh_grad
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False
    torch.cuda.synchronize()


def test_fmoe_linear_distributed():
    # Launch one child process per GPU; each re-runs this file with RANK and
    # WORLD_SIZE set so the __main__ block takes the distributed path.
    import subprocess

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
    ps, n = [], 2
    os.environ["WORLD_SIZE"] = str(n)
    for i in range(n):
        os.environ["RANK"] = str(i)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
        p = subprocess.Popen([sys.executable, __file__], stdout=subprocess.PIPE)
        ps.append(p)
    for p in ps:
        p.wait()
        assert p.poll() == 0


if __name__ == "__main__":
    # os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    # os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    # Default to a single process when WORLD_SIZE is unset, instead of
    # raising KeyError on a direct run.
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    test_fmoe_linear()
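For context on the gate interface the new test assumes: moe.gate(inp) is expected to return one flat expert index per (token, k) slot plus a (batch, 1, top_k) score tensor that BruteForceMoELinear passes to torch.bmm. A hypothetical top-k softmax gate with that output convention is sketched below; it only illustrates the assumed shapes and is not fmoe's actual NaiveGate implementation:

import torch
from torch import nn


class ToyTopKGate(nn.Module):
    """Hypothetical gate mirroring the shapes test_fmoe_linear relies on."""

    def __init__(self, d_model, num_expert, top_k=2):
        super().__init__()
        self.proj = nn.Linear(d_model, num_expert)
        self.top_k = top_k

    def forward(self, inp):
        score = self.proj(inp)                    # (batch, num_expert)
        val, idx = torch.topk(score, self.top_k)  # both (batch, top_k)
        # Row-major flatten matches repeat_interleave(top_k, dim=0):
        # token 0's k slots come first, then token 1's, and so on.
        gate_idx = idx.flatten()                  # (batch * top_k,)
        gate_score = torch.softmax(val, dim=-1).unsqueeze(1)  # (batch, 1, top_k)
        return gate_idx, gate_score


gate = ToyTopKGate(d_model=6, num_expert=4)
gate_idx, gate_score = gate(torch.rand(4, 6))
print(gate_idx.shape, gate_score.shape)  # torch.Size([8]) torch.Size([4, 1, 2])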