Commit f39f411a authored by Sengxian

Customize top_k and add a test for LocalDDP

parent 27c89b5a
@@ -24,6 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
else:
world_size = args.world_size
super().__init__(args.num_experts,
top_k=args.top_k,
d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
world_size=world_size, mp_group=group)
self.bias = torch.nn.parameter.Parameter(
@@ -35,7 +36,7 @@ class MegatronMLP(FMoETransformerMLP):
def fmoefy(model, num_experts=None, distributed_experts=True,
hidden_hidden_size=None):
hidden_hidden_size=None, top_k=None):
r'''
Replace MLP layers in a transformer-based model in Megatron by MoE.
* `model` should be a standard Megatron model that has
@@ -63,6 +64,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
elif not hasattr(args, 'hidden_hidden_size'):
args.hidden_hidden_size = args.hidden_size * 4
if top_k is not None:
args.top_k = top_k
elif not hasattr(args, 'top_k'):
args.top_k = 2
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
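For orientation, here is a minimal sketch of how the new top_k argument could be passed through fmoefy. The wrapper function below and the fmoe.megatron import path are assumptions based on the hunk above, not part of this commit.

```python
# Hedged sketch, not part of the commit: the import path is assumed from the
# adaptor file patched above, and the wrapper function is purely illustrative.
from fmoe.megatron import fmoefy


def moefy_with_custom_top_k(model, num_experts=8, top_k=3):
    """Replace the Megatron MLP layers in `model` with MoE layers that route
    each token to its top-`top_k` experts. If top_k were omitted, the patched
    fmoefy falls back to args.top_k, or to 2 when args has no such attribute."""
    return fmoefy(model, num_experts=num_experts, top_k=top_k)
```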
@@ -8,6 +8,7 @@ import torch
from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp
def _run_distributed(func, world_size, args: Dict):
@@ -78,6 +79,13 @@ test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_siz
)
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_local_ddp(mp_size):
_run_distributed(
_test_fmoe_local_ddp.__name__, mp_size * 2, {"mp_size": mp_size},
)
if __name__ == "__main__":
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
@@ -87,20 +95,30 @@ if __name__ == "__main__":
torch.distributed.init_process_group(backend="nccl")
args["rank"] = torch.distributed.get_rank()
args["world_size"] = torch.distributed.get_world_size()
args["mp_group"] = (
[
torch.distributed.new_group(
ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
backend="nccl",
)
for j in range(args["world_size"] // args["mp_size"])
][args["rank"] // args["mp_size"]]
if args["mp_size"] > 1
else None
)
args["mp_group"] = [
torch.distributed.new_group(
ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
backend="nccl",
)
for j in range(args["world_size"] // args["mp_size"])
][args["rank"] // args["mp_size"]]
args["dp_group"] = [
torch.distributed.new_group(
ranks=[
i * args["mp_size"] + j
for i in range(args["world_size"] // args["mp_size"])
],
backend="nccl",
)
for j in range(args["mp_size"])
][args["rank"] % args["mp_size"]]
args["world_group"] = torch.distributed.new_group(
ranks=list(range(args["world_size"])), backend="nccl",
)
del args["mp_size"]
locals()[sys.argv[1]](**args)
else:
test_fmoe_local_ddp(mp_size=2)
test_fmoe_linear_distributed(
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
)
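As a sanity check on the mp_group and dp_group construction above, here is a standalone sketch, with example sizes and no torch.distributed calls, of which ranks land in which groups.

```python
# Rank layout implied by the group construction above (example values only).
# With world_size = 4 and mp_size = 2 this prints:
#   rank 0: mp [0, 1]  dp [0, 2]
#   rank 1: mp [0, 1]  dp [1, 3]
#   rank 2: mp [2, 3]  dp [0, 2]
#   rank 3: mp [2, 3]  dp [1, 3]
world_size, mp_size = 4, 2

# Same comprehensions as in the test, minus the torch.distributed.new_group calls.
mp_groups = [
    [j * mp_size + i for i in range(mp_size)]
    for j in range(world_size // mp_size)
]
dp_groups = [
    [i * mp_size + j for i in range(world_size // mp_size)]
    for j in range(mp_size)
]

for rank in range(world_size):
    print(f"rank {rank}: mp {mp_groups[rank // mp_size]}  dp {dp_groups[rank % mp_size]}")
```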
import sys
from collections import OrderedDict
from typing import List, Type, Union
import pytest
import torch
import torch.nn as nn
from copy import deepcopy
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
@@ -53,15 +56,16 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
class MyMoE(FMoE):
def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
top_k, activation):
def __init__(
self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
top_k=top_k
top_k=top_k,
)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@@ -74,6 +78,8 @@ class MyMoE(FMoE):
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_linear(
num_expert,
top_k,
@@ -83,13 +89,16 @@ def test_fmoe_linear(
rank,
world_size,
mp_group,
dp_group,
world_group,
activation=torch.nn.functional.gelu,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
activation).cuda()
moe = MyMoE(
num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
).cuda()
moe_raw = BruteForceMoELinear(
activation=activation,
@@ -132,8 +141,20 @@ def test_fmoe_linear(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
moe_out_list = (
moe_out,
moe.experts.htoh4.weight.grad,
moe.experts.h4toh.weight.grad,
moe.experts.htoh4.bias.grad,
moe.experts.h4toh.bias.grad,
)
raw_out_list = (
raw_out,
moe_raw.weight_htoh4.grad,
moe_raw.weight_h4toh.grad,
moe_raw.bias_htoh4.grad,
moe_raw.bias_h4toh.grad,
)
if world_size > 1:
_, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
@@ -142,13 +163,27 @@ def test_fmoe_linear(
torch.distributed.all_reduce(htoh4_b_grad)
torch.distributed.all_reduce(h4toh_b_grad)
mp_size = mp_group.size() if mp_group else 1
htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
htoh4_w_grad = (
htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
h4toh_w_grad = (
h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
htoh4_b_grad = (
htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
h4toh_b_grad = (
h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
names = [
"output",
"htoh4 weight grad",
"h4toh weight grad",
"htoh4 bias grad",
"h4toh bias grad",
]
_assert_numercial(names, moe_out_list, raw_out_list, rank)
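The per-rank slicing and the division by mp_size above are easier to see in isolation; the tiny illustration below uses plain tensors and example values only.

```python
# Illustration only: the brute-force reference holds world_size * num_expert
# expert blocks, so after the all_reduce each rank keeps just its own block
# and averages over the mp_size model-parallel replicas.
import torch

world_size, num_expert, rank, mp_size = 2, 4, 1, 1  # example values
full_grad = torch.arange(world_size * num_expert, dtype=torch.float32)

local_grad = full_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
print(local_grad)  # tensor([4., 5., 6., 7.]) -- the block owned by rank 1
```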
@@ -160,6 +195,8 @@ def test_fmoe_linear(
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
batch_size,
num_expert,
@@ -167,8 +204,10 @@ def test_fmoe(
top_k,
expert: Union[Type[nn.Module], str],
rank,
mp_group,
world_size,
mp_group,
dp_group,
world_group,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
@@ -249,6 +288,82 @@ def test_fmoe(
_assert_numercial(names, moe_out_list, raw_out_list, rank)
class MyModule(nn.Module):
def __init__(self, dim=8):
super(MyModule, self).__init__()
self.model = nn.Sequential(
OrderedDict(
[
("linear1", nn.Linear(dim, dim)),
("relu1", nn.ReLU()),
("linear2", nn.Linear(dim, dim)),
("relu2", nn.ReLU()),
("linear3", nn.Linear(dim, dim)),
]
)
)
def set_comm(self):
for p in self.model._modules["linear1"].parameters():
setattr(p, "dp_comm", "mp")
for p in self.model._modules["linear2"].parameters():
setattr(p, "dp_comm", "dp")
for p in self.model._modules["linear3"].parameters():
setattr(p, "dp_comm", "world")
def forward(self, inp):
return self.model(inp)
def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
batch_size, dim = 4, 8
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
model = MyModule().cuda()
model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
model.set_comm()
model_ddp.module.set_comm()
inp = torch.randn(batch_size, dim).cuda()
raw_out = model(inp).mean()
ddp_out = model_ddp(inp).mean()
raw_out.backward()
ddp_out.backward()
torch.distributed.all_reduce(
model.model._modules["linear1"].weight.grad.data, group=mp_group
)
model.model._modules["linear1"].weight.grad /= mp_group.size()
torch.distributed.all_reduce(
model.model._modules["linear2"].weight.grad.data, group=dp_group
)
model.model._modules["linear2"].weight.grad /= dp_group.size()
torch.distributed.all_reduce(
model.model._modules["linear3"].weight.grad.data, group=world_group
)
model.model._modules["linear3"].weight.grad /= world_group.size()
model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)
raw_out_list = [
model.model._modules["linear1"].weight.grad,
model.model._modules["linear2"].weight.grad,
model.model._modules["linear3"].weight.grad,
]
ddp_out_list = [
model_ddp.module.model._modules["linear1"].weight.grad,
model_ddp.module.model._modules["linear2"].weight.grad,
model_ddp.module.model._modules["linear3"].weight.grad,
]
names = ["mp grad", "dp grad", "wp grad"]
_assert_numercial(names, ddp_out_list, raw_out_list, rank)
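The test above exercises the contract that a parameter's dp_comm tag decides which process group averages its gradient. Below is a simplified, hypothetical sketch of that idea; it is not the DistributedGroupedDataParallel.allreduce_params implementation, and the default tag of "dp" is an assumption.

```python
# Hypothetical sketch of the gradient-averaging contract checked by
# _test_fmoe_local_ddp: all-reduce each gradient over the group named by the
# parameter's dp_comm tag, then divide by that group's size.
import torch.distributed as dist


def allreduce_gradients_by_tag(module, groups):
    """`groups` maps a dp_comm tag ("mp", "dp", or "world") to a process group."""
    for p in module.parameters():
        tag = getattr(p, "dp_comm", "dp")  # default tag is an assumption here
        group = groups.get(tag)
        if group is None or p.grad is None:
            continue
        dist.all_reduce(p.grad.data, group=group)
        p.grad.data /= dist.get_world_size(group=group)
```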
if __name__ == "__main__":
test_fmoe_linear(
batch_size=4,
@@ -259,6 +374,8 @@ if __name__ == "__main__":
rank=0,
world_size=1,
mp_group=None,
dp_group=None,
world_group=None,
)
test_fmoe(
batch_size=4,
@@ -269,4 +386,6 @@ if __name__ == "__main__":
rank=0,
world_size=1,
mp_group=None,
dp_group=None,
world_group=None,
)