Commit 5f8ba136 authored by Rick Ho

tidy up megatron compatibility layer

parent 33fc3aca
......@@ -107,8 +107,8 @@ class FMoE(nn.Module):
self.slice_size = 1
self.slice_rank = 0
else:
self.slice_size = slice_group.size()
self.slice_rank = slice_group.rank()
self.slice_size = self.slice_group.size()
self.slice_rank = self.slice_group.rank()
self.top_k = top_k
if type(expert) is list:
......
......@@ -15,5 +15,6 @@ from .balance import reset_gate_hook
from .balance import get_balance_profile
from .balance import generate_megatron_gate_hook
from .balance import add_balance_log
from .balance import patch_forward_step
from .balance import patch_model_provider
from .patch import patch_forward_step
from .patch import patch_model_provider
......@@ -5,7 +5,6 @@ import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm
from .distributed import get_moe_group
balance_dict = {}
......@@ -71,63 +70,3 @@ def add_balance_log(model, writer, iteration):
balance_dict_tensor[idx].mean().item(),
iteration,
)
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# average across the MoE group
moe_group = get_moe_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=4 * args.hidden_size // args.top_k,
top_k=args.top_k,
)
return fmoefied_model_provider
r"""
distributed support for Megatron
"""
import torch
from fmoe.distributed import DistributedGroupedDataParallel
_moe_group = None
_groups = None
def _set_groups(**kwargs):
global _groups
_groups = kwargs
def set_moe_group(moe_group):
global _moe_group
_moe_group = moe_group
def _init():
from megatron import get_args
from megatron import mpu
args = get_args()
# Create a comm perpendicular to the pipeline group as the gate group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
gate_group = group
def get_moe_group():
return _moe_group
_set_groups(
dp_group=mpu.get_data_parallel_group(),
moe_group=mpu.get_data_parallel_group(),
gate_group=gate_group)
class DistributedDataParallel(DistributedGroupedDataParallel):
......@@ -24,14 +41,9 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
"""
def __init__(self, module):
from megatron import mpu
super().__init__(
module,
mp_group=mpu.get_model_parallel_group(),
dp_group=mpu.get_data_parallel_group(),
moe_group=_moe_group
)
if _groups is None:
_init()
super().__init__(module, **_groups)
def state_dict(self, *args, **kwargs):
r"""
......
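The wrapper above now builds its communicators lazily: the first construction finds _groups unset, calls _init() to create the dp/moe/gate groups, and caches them for reuse. A hedged sketch of how it would be applied in a Megatron training script, assuming megatron has already initialized torch.distributed and that the wrapper is re-exported from fmoe.megatron:

# Illustrative only: wrap a Megatron model with fastmoe's grouped DDP.
from fmoe.megatron import DistributedDataParallel as FMoEDDP

def wrap_model(model):
    # First call: _groups is None, so __init__ runs _init() and caches the
    # dp/moe/gate communicators; subsequent wraps reuse the cached groups.
    return FMoEDDP(model)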
......@@ -10,7 +10,6 @@ import torch.nn.functional as F
from fmoe.transformer import FMoETransformerMLP
from .balance import reset_gate_hook
from .balance import generate_megatron_gate_hook
from .distributed import set_moe_group
class _FakeMegatronMLP(nn.Module):
......@@ -75,31 +74,31 @@ class MegatronMLP(FMoETransformerMLP):
communication group `group` to replace the original MLP layer in Megatron.
"""
def __init__(self, args, mp_group, moe_group, layer_idx):
def __init__(self, args, layer_idx):
assert (
args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
== 0
), "Batch size x sequence length should be multiple of mp size"
if not args.distributed_experts:
world_size = 1
moe_group = None
else:
world_size = args.tensor_model_parallel_size * args.data_parallel_size
world_size = args.data_parallel_size
from megatron.mpu import get_data_parallel_group
moe_group = get_data_parallel_group()
gate = None
if not args.balance_strategy or args.balance_strategy == "naive":
from fmoe.gates import NaiveGate
gate = NaiveGate
elif args.balance_strategy == "noisy":
from fmoe.gates import NoisyGate
gate = NoisyGate
elif args.balance_strategy == "gshard":
from fmoe.gates import GShardGate
gate = GShardGate
elif args.balance_strategy == "switch":
from fmoe.gates import SwitchGate
gate = SwitchGate
else:
assert False, "Undefined balance strategy {}".format(args.balance_strategy)
......@@ -110,7 +109,6 @@ class MegatronMLP(FMoETransformerMLP):
d_model=args.hidden_size,
d_hidden=args.hidden_hidden_size,
world_size=world_size,
mp_group=mp_group,
moe_group=moe_group,
expert_dp_comm="none" if args.distributed_experts else "dp",
gate_hook=generate_megatron_gate_hook(
......@@ -139,8 +137,11 @@ class MegatronMLP(FMoETransformerMLP):
_megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp):
from megatron import mpu
x = super().forward(inp)
x = mpu.reduce_from_tensor_model_parallel_region(x)
return (
super().forward(inp),
x,
torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
)
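Megatron-LM's ParallelMLP.forward returns an (output, output_bias) pair, so the patched forward reduces the MoE output across the tensor-model-parallel region and hands back an all-zero bias to keep that interface intact. A minimal illustrative caller (not Megatron code) that consumes the pair:

# Illustrative caller: treat the MoE MLP like Megatron's ParallelMLP and
# fold the (here all-zero) bias back into the output.
def apply_mlp(mlp, hidden_states):
    output, output_bias = mlp(hidden_states)
    return output + output_bias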
......@@ -167,47 +168,31 @@ def fmoefy(
tensor_model_parallel_comm x data_parallel_comm, which is not created.
"""
from megatron import get_args
from megatron import mpu
args = get_args()
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
if num_experts is not None:
args.num_experts = num_experts
assert (
"num_experts" in args
), "num_experts should be specified in arguments or fmoefy function"
if hidden_hidden_size is not None:
args.hidden_hidden_size = hidden_hidden_size
elif not hasattr(args, "hidden_hidden_size"):
args.hidden_hidden_size = args.hidden_size * 4
if top_k is not None:
args.top_k = top_k
elif not hasattr(args, "top_k"):
args.top_k = 2
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
if hidden_hidden_size is not None:
args.hidden_hidden_size = hidden_hidden_size
elif not hasattr(args, "hidden_hidden_size"):
args.hidden_hidden_size = args.hidden_size * 4 // args.tensor_model_parallel_size
if hasattr(mpu, 'get_tensor_model_parallel_group'):
mp_group = mpu.get_tensor_model_parallel_group()
else:
# For compatibility to older versions of Megatron-LM
mp_group = mpu.get_model_parallel_group()
if args.pipeline_model_parallel_size == 1:
moe_group = None
else:
# Create a comm perpendicular to the pipeline group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
moe_group = group
set_moe_group(moe_group)
for idx, l in enumerate(model.language_model.transformer.layers):
l.mlp = MegatronMLP(args, mp_group, moe_group, idx)
l.mlp = MegatronMLP(args, idx)
# initialize gate hook
num_layers = len(model.language_model.transformer.layers)
......
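A hedged usage sketch of fmoefy on a freshly built Megatron model, assuming fmoefy is importable from fmoe.megatron; the expert count and top_k below are placeholder values, and model_provider stands for the user's usual Megatron model constructor:

# Illustrative only: convert a standard Megatron model into an MoE model by
# replacing every transformer layer's .mlp with a MegatronMLP.
from fmoe.megatron import fmoefy

model = model_provider()          # assumed: builds the usual GPT model
model = fmoefy(model, num_experts=4, top_k=2)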
r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron.mpu import get_tensor_model_parallel_group
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# average across the MoE group
moe_group = get_tensor_model_parallel_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=4 * args.hidden_size // args.top_k,
top_k=args.top_k,
)
return fmoefied_model_provider
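Taken together, the two helpers above wrap the callables that Megatron's pretrain() consumes: the model provider gets fmoefied, and the forward step adds the averaged balance loss to the reported losses. A self-contained sketch of the intended wiring; model_provider and forward_step stand for the user's usual functions from a pretrain_gpt.py-style script:

# Illustrative only: produce drop-in replacements for the two callables that
# Megatron-LM's pretrain() expects.
from fmoe.megatron import patch_forward_step, patch_model_provider

def fmoefied_entry_points(model_provider, forward_step):
    # The patched provider builds the model and fmoefies it according to the
    # MoE arguments; the patched forward step appends the balance loss when a
    # non-naive balance strategy is configured.
    return patch_model_provider(model_provider), patch_forward_step(forward_step)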