Commit 5ead59db authored by Sengxian

Add test for arbitrary expert in FMoE
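This commit adds tests in which FMoE builds its experts from an arbitrary module constructor instead of the built-in FFN expert. A minimal sketch of that usage, assuming only the constructor arguments exercised by test_fmoe below (any callable that takes d_model and returns an nn.Module can serve as the expert):

    import torch.nn as nn
    from fmoe.gates import NaiveGate
    from fmoe.layers import FMoE

    class MyExpert(nn.Module):          # illustrative expert, not part of this commit
        def __init__(self, d_model):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, x):
            return self.proj(x)

    # FMoE instantiates num_expert copies of the given expert class internally.
    moe = FMoE(num_expert=4, d_model=16, gate=NaiveGate, expert=MyExpert, top_k=2).cuda()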

parent fc78d5c3
@@ -116,7 +116,7 @@ class FMoE(nn.Module):
        if expert_fn is None:
            assert expert is not None, 'Either expert or expert_fn should be set'
            self.experts = [expert(d_model) for _ in range(num_expert)]

            def expert_fn(inp, fwd_expert_count):
                outputs = []
                base_idx = 0
                for i in range(self.num_expert):
......
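The body of the fallback expert_fn is collapsed in this view. As a hedged illustration only (not the collapsed lines themselves), the usual dispatch pattern slices the expert-sorted input by fwd_expert_count and runs each expert on its own slice:

    # Sketch of the dispatch loop; assumes inp is already grouped by expert id and
    # fwd_expert_count[i] gives the number of rows destined for expert i.
    def expert_fn(inp, fwd_expert_count):
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i]
            outputs.append(self.experts[i](inp[base_idx : base_idx + batch_size]))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)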
@@ -4,16 +4,24 @@ import torch
class BruteForceMoELinear(nn.Module):
    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_hidden, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model, d_hidden)
        )
        self.top_k = top_k
@@ -29,3 +37,43 @@ class BruteForceMoELinear(nn.Module):
            -1, self.d_model
        )
        return x
class BruteForceMoE(nn.Module):
    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        self.experts = [expert(d_model) for _ in range(num_expert * world_size)]

    def forward(self, inp, gate_idx, gate_score):
        # inp is expected to contain each original token repeated top_k times,
        # so row i is routed to the single expert gate_idx[i].
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            x[i] = self.experts[gate_long[i]](inp[i])
        # Mix the top_k expert outputs of each token with its gate scores.
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x
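BruteForceMoE expects its input already repeated top_k times and mixes the per-row expert outputs back per token with a batched matmul. A small shape check of that combination step (illustrative sizes, independent of the test file):

    import torch

    batch, top_k, d_model = 4, 2, 16
    gate_score = torch.rand(batch, 1, top_k)        # per-token mixing weights
    x = torch.rand(batch * top_k, d_model)          # one expert output per (token, k) pair
    combined = torch.bmm(gate_score, x.view(-1, top_k, d_model)).reshape(-1, d_model)
    assert combined.shape == (batch, d_model)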
class NaiveExpert(nn.Module):
    def __init__(self, d_model):
        super(NaiveExpert, self).__init__()
        self.linear = nn.Linear(d_model, d_model).cuda()

    def forward(self, x):
        return self.linear(x)


class LinearExpert(nn.Module):
    def __init__(self, d_model):
        super(LinearExpert, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model),
        ).cuda()

    def forward(self, x):
        return self.model(x)
import json
import os
import sys
from typing import List, Callable, Dict, Type, Union

import pytest
import torch
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
rank = 0
world_size = 1
def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
    # Run the fused MoE and the brute-force reference on the same random batch,
    # reusing the gate decisions of the fused module for both.
    moe.zero_grad()
    moe_raw.zero_grad()
    inp = torch.rand(batch_size, d_model).cuda()
    gate_idx, gate_score = moe.gate(inp)
    inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
    moe_out = moe(inp).mean()
    raw_out = moe_raw(inp_repeated, gate_idx, gate_score).mean()
    moe_out.backward()
    raw_out.backward()
    return moe_out, raw_out
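repeat_interleave keeps the top_k copies of each token adjacent, which is the row layout BruteForceMoE consumes (one row per (token, chosen expert) pair, driven by the gate decisions taken from the fused module). A tiny illustration of that ordering, not part of the test file:

    import torch

    t = torch.tensor([[1.0], [2.0]])
    # Copies of token 0 come first, then copies of token 1.
    print(t.repeat_interleave(repeats=2, dim=0))    # tensor([[1.], [1.], [2.], [2.]])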
def _assert_numercial(names, moe_out_list, raw_out_list):
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    activation=torch.nn.functional.gelu,
):
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
@@ -40,6 +78,7 @@ def test_fmoe_linear():
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
    ).cuda()
@@ -59,54 +98,119 @@ def test_fmoe_linear():
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)
    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad

    if world_size > 1:
        _, htoh4_grad, h4toh_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
        raw_out_list = _, htoh4_grad, h4toh_grad

    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    _assert_numercial(names, moe_out_list, raw_out_list)
    torch.cuda.synchronize()
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe(
batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str]
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if isinstance(expert, str):
expert = globals()[expert]
def test_fmoe_linear_distributed(): moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=None,
expert=expert,
top_k=top_k,
).cuda()
moe_raw = BruteForceMoE(
expert=expert, num_expert=num_expert, d_model=d_model, world_size=world_size,
).cuda()
    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        # Gather the idx-th parameter of every local expert from all ranks, then
        # copy the gathered tensors into the corresponding brute-force experts.
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expert_id].parameters())[idx].data = (
                    para_tensor_gathered[expert_id]
                )

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse each expert's parameter gradients into a single scalar per expert.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert]

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]
    _assert_numercial(names, moe_out_list, raw_out_list)
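get_experts_grad reduces each expert's parameter gradients to one scalar, so the "backward" comparison checks aggregate gradient mass per expert rather than matching gradients elementwise. A tiny, self-contained illustration of that reduction (not part of the test file):

    import torch
    import torch.nn as nn

    e = nn.Linear(3, 3)                      # stand-in for one expert
    e(torch.randn(2, 3)).sum().backward()    # populate gradients
    per_expert_scalar = torch.stack([p.grad.sum() for p in e.parameters()]).sum()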

def _run_distributed(func: Callable, args: Dict):
    import subprocess
    import os

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
    ps, n = [], 2
    os.environ["WORLD_SIZE"] = str(n)
    for i in range(n):
        os.environ["RANK"] = str(i)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
        p = subprocess.Popen(
            [sys.executable, __file__, func.__name__, json.dumps(args)],
            stdout=subprocess.PIPE,
        )
        ps.append(p)

    for p in ps:
@@ -115,11 +219,55 @@ def test_fmoe_linear_distributed():
        assert retc == 0
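_run_distributed re-invokes this test file as a script in each worker process, passing the target test's name and its keyword arguments as JSON on the command line; the __main__ block below dispatches on them. Roughly what one spawned worker executes, with an illustrative file name (the real path comes from __file__):

    #   RANK=0 WORLD_SIZE=2 CUDA_VISIBLE_DEVICES=0 MASTER_ADDR=localhost MASTER_PORT=36666 \
    #   python tests/test_numerical.py test_fmoe_linear \
    #       '{"num_expert": 4, "top_k": 2, "batch_size": 4, "d_model": 16, "d_hidden": 32}'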
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear_distributed(
num_expert, top_k, batch_size, d_model, d_hidden,
):
_run_distributed(
test_fmoe_linear,
{
"num_expert": num_expert,
"top_k": top_k,
"batch_size": batch_size,
"d_model": d_model,
"d_hidden": d_hidden,
},
)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe_distributed(
num_expert, top_k, batch_size, d_model, expert,
):
_run_distributed(
test_fmoe,
{
"num_expert": num_expert,
"top_k": top_k,
"batch_size": batch_size,
"d_model": d_model,
"expert": expert,
},
)

if __name__ == "__main__":
    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    if int(os.environ["WORLD_SIZE"]) > 1:
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    if len(sys.argv) >= 3:
        locals()[sys.argv[1]](**json.loads(sys.argv[2]))
    else:
        test_fmoe_linear(batch_size=4, num_expert=4, d_model=8, top_k=2, d_hidden=16)
        test_fmoe(batch_size=4, num_expert=4, d_model=8, top_k=2, expert=NaiveExpert)
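With that dispatch in place, the file can be driven either by pytest or directly; assumed invocations (the file path is illustrative, not taken from the diff):

    # Single-process parametrized tests:
    #   pytest tests/test_numerical.py -k "fmoe and not distributed"
    # Default standalone run (the two calls above):
    #   python tests/test_numerical.py
    # As a distributed worker, with an explicit test name and JSON kwargs:
    #   python tests/test_numerical.py test_fmoe '{"batch_size": 4, "num_expert": 4, "d_model": 8, "top_k": 2, "expert": "NaiveExpert"}'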