Commit dad99892 authored by Sengxian

Add model-parallel test for FMoE with DDP

parent 06e75b3a
@@ -10,19 +10,19 @@
 from test_numerical import test_fmoe as _test_fmoe
 from test_numerical import test_fmoe_linear as _test_fmoe_linear
 
 
-def _run_distributed(func, args: Dict):
+def _run_distributed(func, world_size, args: Dict):
     import subprocess
     import os
 
-    ps, n = [], 2
+    ps = []
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "36666"
-    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
-    for i in range(n):
+    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
+    for i in range(world_size):
         os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen(
-            [sys.executable, __file__, func, json.dumps(args)], stdout=subprocess.PIPE,
+            [sys.executable, __file__, func, json.dumps(args)], stdout=subprocess.PIPE
        )
         ps.append(p)
 
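For reference, _run_distributed emulates an MPI-style launch: it exports OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_RANK and re-executes this test file once per rank; the loop that waits on the children presumably sits in the part of the helper not shown in this hunk. A minimal self-contained sketch of the same launch pattern, where worker.py is a hypothetical stand-in for the re-executed test file:

    import os
    import subprocess
    import sys

    def launch(world_size):
        # Rendezvous info shared by all workers.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "36666"
        os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
        procs = []
        for rank in range(world_size):
            # Each child inherits its own rank via the environment.
            os.environ["OMPI_COMM_WORLD_RANK"] = str(rank)
            procs.append(subprocess.Popen([sys.executable, "worker.py"]))
        for p in procs:
            assert p.wait() == 0  # fail loudly if any rank crashed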
@@ -37,17 +37,20 @@ def _run_distributed(func, args: Dict):
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
+@pytest.mark.parametrize("mp_size", [1, 2])
 def test_fmoe_linear_distributed(
-    num_expert, top_k, batch_size, d_model, d_hidden,
+    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
 ):
     _run_distributed(
         "_test_fmoe_linear",
+        mp_size * 2,
         {
             "num_expert": num_expert,
             "top_k": top_k,
             "batch_size": batch_size,
             "d_model": d_model,
             "d_hidden": d_hidden,
+            "mp_size": mp_size,
         },
     )
 
@@ -57,17 +60,18 @@ def test_fmoe_linear_distributed(
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
-def test_fmoe_distributed(
-    num_expert, top_k, batch_size, d_model, expert,
-):
+@pytest.mark.parametrize("mp_size", [1, 2])
+def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
     _run_distributed(
         "_test_fmoe",
+        mp_size * 2,
         {
             "num_expert": num_expert,
             "top_k": top_k,
             "batch_size": batch_size,
             "d_model": d_model,
             "expert": expert,
+            "mp_size": mp_size,
         },
     )
 
@@ -81,4 +85,20 @@ if __name__ == "__main__":
     torch.distributed.init_process_group(backend="nccl")
     args["rank"] = torch.distributed.get_rank()
     args["world_size"] = torch.distributed.get_world_size()
+    args["mp_group"] = (
+        [
+            torch.distributed.new_group(
+                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
+                backend="nccl",
+            )
+            for j in range(args["world_size"] // args["mp_size"])
+        ][args["rank"] // args["mp_size"]]
+        if args["mp_size"] > 1
+        else None
+    )
+    del args["mp_size"]
     locals()[sys.argv[1]](**args)
+else:
+    test_fmoe_linear_distributed(
+        num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+    )
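The mp_group block partitions the world into groups of mp_size consecutive ranks, one NCCL group per data-parallel replica; each process keeps the group containing its own rank. The driver passes mp_size * 2 as the world size, i.e. two data-parallel replicas of mp_size model-parallel shards each. A torch-free sketch of the same arithmetic, with illustrative sizes:

    def mp_groups(world_size, mp_size):
        # One group of mp_size consecutive ranks per data-parallel replica.
        return [
            [j * mp_size + i for i in range(mp_size)]
            for j in range(world_size // mp_size)
        ]

    world_size, mp_size = 4, 2
    groups = mp_groups(world_size, mp_size)
    assert groups == [[0, 1], [2, 3]]
    for rank in range(world_size):
        # Each rank indexes the group that contains it, mirroring
        # [...][args["rank"] // args["mp_size"]] above.
        assert rank in groups[rank // mp_size]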
Second changed file in this commit: the numerical test module imported by the driver above.

@@ -1,7 +1,5 @@
-import json
-import os
 import sys
-from typing import List, Callable, Dict, Type, Union
+from typing import List, Type, Union
 
 import pytest
 import torch
@@ -13,10 +11,24 @@ from fmoe.transformer import _Expert
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 
 
-def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
+def _perform_forward(
+    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+):
     moe.zero_grad()
     moe_raw.zero_grad()
-    inp = torch.rand(batch_size, d_model).cuda()
+    if not mp_group:
+        inp = torch.rand(batch_size, d_model).cuda()
+    else:
+        group_sender = rank // mp_group.size() * mp_group.size()
+        inp = torch.rand(batch_size, d_model).cuda()
+        torch.distributed.broadcast(inp, group_sender, group=mp_group)
+        torch.distributed.broadcast(
+            moe.gate.gate.weight.data, group_sender, group=mp_group
+        )
+        torch.distributed.broadcast(
+            moe.gate.gate.bias.data, group_sender, group=mp_group
+        )
     gate_idx, gate_score = moe.gate(inp)
     inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
     moe_out = moe(inp).mean()
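Here group_sender is the lowest global rank in the caller's model-parallel group, so every member of a group broadcasts from the same source and all replicas see identical inputs and gate parameters (torch.distributed.broadcast takes the global source rank even when a group is given). The index arithmetic, checked standalone with illustrative sizes and assuming groups of consecutive ranks as built by the driver:

    mp_size = 2
    for rank in range(4):
        group_sender = rank // mp_size * mp_size  # floor to the group's first rank
        assert group_sender == (0 if rank in (0, 1) else 2)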
@@ -47,6 +59,7 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
+@pytest.mark.parametrize("mp_group", [None])
 def test_fmoe_linear(
     num_expert,
     top_k,
@@ -55,6 +68,7 @@ def test_fmoe_linear(
     d_hidden,
     rank,
     world_size,
+    mp_group,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
@@ -70,7 +84,7 @@ def test_fmoe_linear(
         d_model=d_model,
         gate=NaiveGate,
         world_size=world_size,
-        mp_group=None,
+        mp_group=mp_group,
         expert_fn=expert_fn,
         top_k=top_k,
     ).cuda()
@@ -100,7 +114,9 @@ def test_fmoe_linear(
     torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
     moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
 
-    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)
+    moe_out, raw_out = _perform_forward(
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+    )
 
     moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
     raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
@@ -109,8 +125,9 @@ def test_fmoe_linear(
         _, htoh4_grad, h4toh_grad = raw_out_list
         torch.distributed.all_reduce(htoh4_grad)
         torch.distributed.all_reduce(h4toh_grad)
-        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
-        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
+        mp_size = mp_group.size() if mp_group else 1
+        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
         raw_out_list = _, htoh4_grad, h4toh_grad
 
     names = ["output", "htoh4 weight grad", "h4toh weight grad"]
@@ -121,9 +138,10 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("top_k", [2, 3])
-@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
+@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
+@pytest.mark.parametrize("mp_group", [None])
 def test_fmoe(
     batch_size,
     num_expert,
@@ -131,6 +149,7 @@ def test_fmoe(
     top_k,
     expert: Union[Type[nn.Module], str],
     rank,
+    mp_group,
     world_size,
 ):
     torch.manual_seed(42 + rank)
@@ -144,7 +163,7 @@ def test_fmoe(
         d_model=d_model,
         gate=NaiveGate,
         world_size=world_size,
-        mp_group=None,
+        mp_group=mp_group,
         expert=expert,
         top_k=top_k,
     ).cuda()
@@ -178,7 +197,9 @@ def test_fmoe(
                 idx
             ].data = para_tensor_gathered[expertID]
 
-    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)
+    moe_out, raw_out = _perform_forward(
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+    )
 
     def get_experts_grad(experts: List[nn.Module]):
         return torch.stack(
@@ -200,7 +221,8 @@ def test_fmoe(
 
     if world_size > 1:
         torch.distributed.all_reduce(raw_grad)
-        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert]
+        mp_size = mp_group.size() if mp_group else 1
+        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
 
     moe_out_list = [moe_out, moe_grad]
     raw_out_list = [raw_out, raw_grad]
@@ -218,6 +240,7 @@ if __name__ == "__main__":
         d_hidden=16,
         rank=0,
         world_size=1,
+        mp_group=None,
     )
     test_fmoe(
         batch_size=4,
@@ -227,4 +250,5 @@ if __name__ == "__main__":
         expert=NaiveExpert,
         rank=0,
         world_size=1,
+        mp_group=None,
     )