Commit f39f411a authored by Sengxian

customize top_k and add a test for local DDP

parent 27c89b5a
@@ -24,6 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         super().__init__(args.num_experts,
+                top_k=args.top_k,
                 d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
                 world_size=world_size, mp_group=group)
         self.bias = torch.nn.parameter.Parameter(
@@ -35,7 +36,7 @@ class MegatronMLP(FMoETransformerMLP):


 def fmoefy(model, num_experts=None, distributed_experts=True,
-        hidden_hidden_size=None):
+        hidden_hidden_size=None, top_k=None):
     r'''
     Replace MLP layers in a transformer-based model in Megatron by MoE.
     * `model` should be a standard Megatron model that has
@@ -63,6 +64,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
     elif not hasattr(args, 'hidden_hidden_size'):
         args.hidden_hidden_size = args.hidden_size * 4

+    if top_k is not None:
+        args.top_k = top_k
+    elif not hasattr(args, 'top_k'):
+        args.top_k = 2
+
     # Set distributed_experts to None to use default setting in args
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
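For reference, a minimal usage sketch of the new argument, assuming fmoefy is exposed from fmoe.megatron as in the rest of the repository; the model object below is a placeholder, not part of this commit:

# Hypothetical usage, not part of the diff: `model` stands in for a real
# Megatron transformer whose MLP layers fmoefy replaces with MoE layers.
from fmoe.megatron import fmoefy

# An explicit top_k wins; otherwise fmoefy falls back to args.top_k when the
# Megatron args already define it, and finally to the default of 2.
model = fmoefy(model, num_experts=4, top_k=2)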
@@ -8,6 +8,7 @@ import torch
 from test_numerical import test_fmoe as _test_fmoe
 from test_numerical import test_fmoe_linear as _test_fmoe_linear
+from test_numerical import _test_fmoe_local_ddp


 def _run_distributed(func, world_size, args: Dict):
@@ -78,6 +79,13 @@ def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size
     )


+@pytest.mark.parametrize("mp_size", [1, 2])
+def test_fmoe_local_ddp(mp_size):
+    _run_distributed(
+        _test_fmoe_local_ddp.__name__, mp_size * 2, {"mp_size": mp_size},
+    )
+
+
 if __name__ == "__main__":
     if len(sys.argv) >= 3:
         args = json.loads(sys.argv[2])
@@ -87,20 +95,30 @@ if __name__ == "__main__":
         torch.distributed.init_process_group(backend="nccl")
         args["rank"] = torch.distributed.get_rank()
         args["world_size"] = torch.distributed.get_world_size()
-        args["mp_group"] = (
-            [
-                torch.distributed.new_group(
-                    ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
-                    backend="nccl",
-                )
-                for j in range(args["world_size"] // args["mp_size"])
-            ][args["rank"] // args["mp_size"]]
-            if args["mp_size"] > 1
-            else None
+        args["mp_group"] = [
+            torch.distributed.new_group(
+                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
+                backend="nccl",
+            )
+            for j in range(args["world_size"] // args["mp_size"])
+        ][args["rank"] // args["mp_size"]]
+        args["dp_group"] = [
+            torch.distributed.new_group(
+                ranks=[
+                    i * args["mp_size"] + j
+                    for i in range(args["world_size"] // args["mp_size"])
+                ],
+                backend="nccl",
+            )
+            for j in range(args["mp_size"])
+        ][args["rank"] % args["mp_size"]]
+        args["world_group"] = torch.distributed.new_group(
+            ranks=list(range(args["world_size"])), backend="nccl",
         )
         del args["mp_size"]
         locals()[sys.argv[1]](**args)
     else:
+        test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
             num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
         )
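The group construction above assumes that mp_size consecutive ranks share one model-parallel group, ranks at the same offset within their block share one data-parallel group, and every rank joins the world group. A small standalone sketch of that rank arithmetic (plain Python with illustrative sizes; no process groups are created here):

# Reproduce the rank -> group mapping used by the test driver above,
# for world_size = 4 and mp_size = 2 (illustrative values only).
world_size, mp_size = 4, 2

mp_groups = [
    [j * mp_size + i for i in range(mp_size)]
    for j in range(world_size // mp_size)
]
dp_groups = [
    [i * mp_size + j for i in range(world_size // mp_size)]
    for j in range(mp_size)
]

for rank in range(world_size):
    print(rank, mp_groups[rank // mp_size], dp_groups[rank % mp_size])
# 0 [0, 1] [0, 2]
# 1 [0, 1] [1, 3]
# 2 [2, 3] [0, 2]
# 3 [2, 3] [1, 3]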
 import sys
+from collections import OrderedDict
 from typing import List, Type, Union
 import pytest
 import torch
 import torch.nn as nn
+from copy import deepcopy
 from fmoe.gates import NaiveGate
 from fmoe.layers import FMoE
 from fmoe.transformer import _Expert
+from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
@@ -53,15 +56,16 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):


 class MyMoE(FMoE):
-    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
-            top_k, activation):
+    def __init__(
+        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
+    ):
         super().__init__(
             num_expert=num_expert,
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
             mp_group=mp_group,
-            top_k=top_k
+            top_k=top_k,
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@@ -74,6 +78,8 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
 @pytest.mark.parametrize("mp_group", [None])
+@pytest.mark.parametrize("dp_group", [None])
+@pytest.mark.parametrize("world_group", [None])
 def test_fmoe_linear(
     num_expert,
     top_k,
@@ -83,13 +89,16 @@ def test_fmoe_linear(
     rank,
     world_size,
     mp_group,
+    dp_group,
+    world_group,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)

-    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
-            activation).cuda()
+    moe = MyMoE(
+        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
+    ).cuda()

     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -132,8 +141,20 @@ def test_fmoe_linear(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )

-    moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
+    moe_out_list = (
+        moe_out,
+        moe.experts.htoh4.weight.grad,
+        moe.experts.h4toh.weight.grad,
+        moe.experts.htoh4.bias.grad,
+        moe.experts.h4toh.bias.grad,
+    )
+    raw_out_list = (
+        raw_out,
+        moe_raw.weight_htoh4.grad,
+        moe_raw.weight_h4toh.grad,
+        moe_raw.bias_htoh4.grad,
+        moe_raw.bias_h4toh.grad,
+    )

     if world_size > 1:
         _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
@@ -142,13 +163,27 @@ def test_fmoe_linear(
         torch.distributed.all_reduce(htoh4_b_grad)
         torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        htoh4_w_grad = (
+            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        )
+        h4toh_w_grad = (
+            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        )
+        htoh4_b_grad = (
+            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        )
+        h4toh_b_grad = (
+            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        )
         raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

-    names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
+    names = [
+        "output",
+        "htoh4 weight grad",
+        "h4toh weight grad",
+        "htoh4 bias grad",
+        "h4toh bias grad",
+    ]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)
@@ -160,6 +195,8 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
 @pytest.mark.parametrize("mp_group", [None])
+@pytest.mark.parametrize("dp_group", [None])
+@pytest.mark.parametrize("world_group", [None])
 def test_fmoe(
     batch_size,
     num_expert,
@@ -167,8 +204,10 @@ def test_fmoe(
     top_k,
     expert: Union[Type[nn.Module], str],
     rank,
-    mp_group,
     world_size,
+    mp_group,
+    dp_group,
+    world_group,
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
@@ -249,6 +288,82 @@ def test_fmoe(
     _assert_numercial(names, moe_out_list, raw_out_list, rank)


+class MyModule(nn.Module):
+    def __init__(self, dim=8):
+        super(MyModule, self).__init__()
+        self.model = nn.Sequential(
+            OrderedDict(
+                [
+                    ("linear1", nn.Linear(dim, dim)),
+                    ("relu1", nn.ReLU()),
+                    ("linear2", nn.Linear(dim, dim)),
+                    ("relu2", nn.ReLU()),
+                    ("linear3", nn.Linear(dim, dim)),
+                ]
+            )
+        )
+
+    def set_comm(self):
+        for p in self.model._modules["linear1"].parameters():
+            setattr(p, "dp_comm", "mp")
+        for p in self.model._modules["linear2"].parameters():
+            setattr(p, "dp_comm", "dp")
+        for p in self.model._modules["linear3"].parameters():
+            setattr(p, "dp_comm", "world")
+
+    def forward(self, inp):
+        return self.model(inp)
+
+
+def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
+    batch_size, dim = 4, 8
+    torch.manual_seed(42 + rank)
+    torch.cuda.manual_seed(42 + rank)
+
+    model = MyModule().cuda()
+    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
+    model.set_comm()
+    model_ddp.module.set_comm()
+
+    inp = torch.randn(batch_size, dim).cuda()
+    raw_out = model(inp).mean()
+    ddp_out = model_ddp(inp).mean()
+    raw_out.backward()
+    ddp_out.backward()
+
+    torch.distributed.all_reduce(
+        model.model._modules["linear1"].weight.grad.data, group=mp_group
+    )
+    model.model._modules["linear1"].weight.grad /= mp_group.size()
+    torch.distributed.all_reduce(
+        model.model._modules["linear2"].weight.grad.data, group=dp_group
+    )
+    model.model._modules["linear2"].weight.grad /= dp_group.size()
+    torch.distributed.all_reduce(
+        model.model._modules["linear3"].weight.grad.data, group=world_group
+    )
+    model.model._modules["linear3"].weight.grad /= world_group.size()
+
+    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)
+
+    raw_out_list = [
+        model.model._modules["linear1"].weight.grad,
+        model.model._modules["linear2"].weight.grad,
+        model.model._modules["linear3"].weight.grad,
+    ]
+    ddp_out_list = [
+        model_ddp.module.model._modules["linear1"].weight.grad,
+        model_ddp.module.model._modules["linear2"].weight.grad,
+        model_ddp.module.model._modules["linear3"].weight.grad,
+    ]
+    names = ["mp grad", "dp grad", "wp grad"]
+    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+
+
 if __name__ == "__main__":
     test_fmoe_linear(
         batch_size=4,
@@ -259,6 +374,8 @@ if __name__ == "__main__":
         rank=0,
         world_size=1,
         mp_group=None,
+        dp_group=None,
+        world_group=None,
     )
     test_fmoe(
         batch_size=4,
@@ -269,4 +386,6 @@ if __name__ == "__main__":
         rank=0,
         world_size=1,
         mp_group=None,
+        dp_group=None,
+        world_group=None,
     )
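The _test_fmoe_local_ddp reference path above tags each parameter with a dp_comm attribute ("mp", "dp", or "world") so that DistributedGroupedDataParallel can pick the matching process group when allreduce_params averages gradients; the manual all_reduce and divide calls build the expected result for comparison. A minimal sketch of the tagging pattern in isolation (the helper name is illustrative; only the dp_comm attribute and its values come from the test above):

import torch.nn as nn

def tag_dp_comm(module: nn.Module, comm: str) -> None:
    # Mark every parameter of `module` for gradient averaging over the group
    # type `comm` ("mp", "dp", or "world"), mirroring MyModule.set_comm above.
    for p in module.parameters():
        setattr(p, "dp_comm", comm)

layer = nn.Linear(8, 8)
tag_dp_comm(layer, "dp")  # an ordinary data-parallel layer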