Commit fc78d5c3 authored by Sengxian's avatar Sengxian

Add test for FMoE

parent 103343ca
import math
from torch import nn
import torch
import torch.nn.functional as F
from fmoe.layers import FMoELinear, _fmoe_full_forward


class FMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=1):
        super(FMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        self.linear = FMoELinear(num_expert, in_feat, out_feat)
        self.weight = self.linear.weight
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, inp, gate):
        return _fmoe_full_forward(inp, gate, [self.linear], None,
                                  self.num_expert, self.world_size)
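

# Illustration only (not part of the commit): a rough sketch of how the FMoE
# wrapper above is driven, mirroring the test further below. It assumes a CUDA
# device and a built fmoe extension; the helper is hypothetical and is never
# called from this module.
def _demo_fmoe_wrapper(batch_size=4, num_expert=2, in_feat=6, out_feat=7):
    moe = FMoE(num_expert, in_feat, out_feat, world_size=1).cuda()
    inp = torch.rand(batch_size, in_feat).cuda()
    # one expert index per input row
    gate = torch.randint(low=0, high=num_expert, size=(batch_size,)).int().cuda()
    out = moe(inp, gate)
    return out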


class BruteForceMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=0):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # initialize each local expert the same way nn.Linear does
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat,
                               out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # route every row to its selected expert, one row at a time
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x


class BruteForceMoELinear(nn.Module):
    def __init__(self, activation, num_expert=32, d_model=1024,
                 world_size=1, top_k=2):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model * 4, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model, d_model * 4)
        )
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        # two-layer expert FFN applied row by row
        for i in range(batch_size):
            t = inp[i] @ self.weight_htoh4[gate_long[i]].t()
            t = self.activation(t)
            x[i] = t @ self.weight_h4toh[gate_long[i]].t()
        # recombine the top-k expert outputs of each token with its gate scores
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x
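

# Illustration only (not part of the commit): a minimal sketch of the shapes
# BruteForceMoELinear expects, based on how the test below drives it. Each
# input row is repeated top_k times, gate_idx holds one expert index per
# repeated row, and gate_score is assumed to have shape (batch, 1, top_k) so
# the final bmm produces one weighted combination per original row. The helper
# is hypothetical and is never called from this module.
def _demo_brute_force_linear(batch_size=4, num_expert=2, d_model=6, top_k=2):
    moe_raw = BruteForceMoELinear(activation=F.gelu, num_expert=num_expert,
                                  d_model=d_model, world_size=1, top_k=top_k)
    # start from deterministic weights so the sketch is self-contained
    nn.init.normal_(moe_raw.weight_htoh4, std=0.02)
    nn.init.normal_(moe_raw.weight_h4toh, std=0.02)
    inp = torch.rand(batch_size, d_model)
    inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
    gate_idx = torch.randint(0, num_expert, (batch_size * top_k,))
    gate_score = torch.full((batch_size, 1, top_k), 1.0 / top_k)
    out = moe_raw(inp_repeated, gate_idx, gate_score)
    assert out.shape == (batch_size, d_model)
    return out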
from moe import FMoE as MOELayer
from moe import BruteForceMoE as MOELayer_raw
from moe import BruteForceMoELinear
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from fmoe.gates import NaiveGate
import torch
from torch import nn
import sys
import os

# single-process defaults; overridden in the __main__ block below when the
# tests are launched with more than one process
rank = 0
world_size = 1


def test_moe():
    # compare the FMoELinear-based MOELayer against the brute-force reference
    def test_module(moe, linear, inp, gate):
        linear.zero_grad()
        moe.zero_grad()
        x = linear(inp)
        output = moe(x, gate)
        y = output.mean()
        y.backward()
        return output, moe.weight.grad, linear.weight.grad, linear.bias.grad

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7
    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        # gather every rank's expert weights so the reference sees all experts
        weight_array = [torch.empty_like(moe.weight.data)
                        for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data)
        moe_raw.weight.data = torch.cat(weight_array, dim=0)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * world_size,
                         size=(batch_size,),
                         requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    if world_size > 1:
        ou, wg, lwg, lbg = raw_out
        torch.distributed.all_reduce(wg)
        # keep only the gradient slice that belongs to this rank's experts
        wg = wg[rank * num_expert:(rank + 1) * num_expert]
        raw_out = ou, wg, lwg, lbg
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('Rank {} {} abs err {}'.format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write('=========== moe out ==============\n')
            sys.stderr.write('{}\n'.format(mo))
            sys.stderr.write('=========== raw out ==============\n')
            sys.stderr.write('{}\n'.format(ro))
            return


def test_fmoe_linear():
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    batch_size = 4
    num_expert = 2
    d_model = 6
    d_hidden = 8
    top_k = 2
    activation = torch.nn.functional.gelu

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()
    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
    ).cuda()

    if world_size == 1:
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
    else:
        # gather every rank's expert weights so the reference sees all experts
        weight_htoh4_array = [
            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        weight_h4toh_array = [
            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)

    inp = torch.rand(batch_size, d_model).cuda()
    gate_idx, gate_score = moe.gate(inp)
    print(gate_idx.shape, gate_score.shape)
    # the reference expects each row repeated once per selected expert
    inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
    moe_out = moe(inp).mean()
    raw_out = moe_raw(inp_repeated, gate_idx, gate_score).mean()
    moe_out.backward()
    raw_out.backward()

    moe_out = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    if world_size > 1:
        ou, htoh4_grad, h4toh_grad = raw_out
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        # keep only the gradient slice that belongs to this rank's experts
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
        raw_out = ou, htoh4_grad, h4toh_grad
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False
    torch.cuda.synchronize()


def test_fmoe_linear_distributed():
    import subprocess
    import os
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
    ps, n = [], 2
    os.environ["WORLD_SIZE"] = str(n)
    for i in range(n):
        os.environ["RANK"] = str(i)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
        # re-run this file in a child process, one process per GPU
        p = subprocess.Popen([sys.executable, __file__], stdout=subprocess.PIPE)
        ps.append(p)
    for p in ps:
        p.wait()
        retc = p.poll()
        assert retc == 0
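
# Note on the distributed test above: every child process gets its own RANK and
# CUDA_VISIBLE_DEVICES, so each worker sees exactly one GPU and joins the NCCL
# process group through the __main__ block below. MASTER_ADDR/MASTER_PORT just
# point the workers at a rendezvous on localhost; the port number is arbitrary.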
if __name__ == "__main__":
# os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
# os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
if int(os.environ["WORLD_SIZE"]) > 1:
torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
rank = 0
world_size = 1
test_moe()
test_fmoe_linear()