Commit ae10e942 authored by Rick Ho

update calculation for megatron hhs

parent 07a5d8ac
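
For context, the hhs (hidden_hidden_size) calculation introduced below derives the per-expert FFN width from the model hidden size by dividing the usual 4x expansion by both top_k and the tensor-model-parallel size, asserting divisibility at each step. A minimal, self-contained sketch of that arithmetic; the concrete values are illustrative assumptions, not taken from this commit:

# Sketch of the hhs arithmetic added in patch_model_provider below.
# hidden_size, top_k and tensor_model_parallel_size are assumed example values.
hidden_size = 1024
top_k = 2
tensor_model_parallel_size = 2

hhs = hidden_size * 4                        # 4096: the dense MLP's 4x FFN expansion
assert hhs % top_k == 0
hhs = hhs // top_k                           # 2048: divided across the top-k routed experts
assert hhs % tensor_model_parallel_size == 0
hhs = hhs // tensor_model_parallel_size      # 1024: sharded over tensor-parallel ranks
print(hhs)                                   # -> 1024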
@@ -74,11 +74,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """
-    def __init__(self, args, layer_idx):
-        assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
-            == 0
-        ), "Batch size x sequence length should be multiple of mp size"
+    def __init__(self, args, layer_idx, gate=None):
         if not args.distributed_experts:
             world_size = 1
             moe_group = None
@@ -87,7 +83,6 @@ class MegatronMLP(FMoETransformerMLP):
             from megatron.mpu import get_data_parallel_group
             moe_group = get_data_parallel_group()
-        gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
@@ -100,7 +95,7 @@ class MegatronMLP(FMoETransformerMLP):
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
-        else:
+        elif gate is None:
             assert False, "Undefined balance strategy {}".format(args.balance_strategy)
         super().__init__(
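
With the `elif gate is None` change above, the "undefined balance strategy" assertion only fires when neither a recognized args.balance_strategy nor an explicitly passed gate class is available. A minimal sketch of that selection logic, abridged to two strategies; select_gate is a hypothetical helper written only to illustrate the precedence, not code from this repository:

from fmoe.gates import NaiveGate, SwitchGate

def select_gate(balance_strategy, gate=None):
    # Abridged paraphrase of the branch above (noisy/gshard branches omitted).
    if not balance_strategy or balance_strategy == "naive":
        gate = NaiveGate
    elif balance_strategy == "switch":
        gate = SwitchGate
    elif gate is None:
        # Mirrors the `assert False, ...` in the original code.
        raise AssertionError("Undefined balance strategy {}".format(balance_strategy))
    return gate

select_gate("switch")                       # -> SwitchGate
select_gate("custom", gate=SwitchGate)      # unrecognized strategy, but the passed gate is kept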
@@ -152,6 +147,7 @@ def fmoefy(
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
+    gate=None,
 ):
     r"""
     Replace MLP layers in a transformer-based model in Megatron by MoE.
@@ -186,13 +182,10 @@ def fmoefy(
     elif not hasattr(args, "top_k"):
         args.top_k = 2
-    if hidden_hidden_size is not None:
-        args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, "hidden_hidden_size"):
-        args.hidden_hidden_size = args.hidden_size * 4 // args.tensor_model_parallel_size
+    args.hidden_hidden_size = hidden_hidden_size
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, idx)
+        l.mlp = MegatronMLP(args, idx, gate=gate)
     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
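
A usage sketch for the updated fmoefy signature. Two behavioral notes from the hunk above: args.hidden_hidden_size is now assigned unconditionally from the keyword argument (the old fallback of hidden_size * 4 // tensor_model_parallel_size is gone, so callers are expected to pass it, as patch_model_provider does below), and a gate class can be forwarded to every MegatronMLP. The wrapper function and its argument values here are illustrative assumptions; fmoe.megatron is the assumed import path:

from fmoe.gates import NaiveGate
from fmoe.megatron import fmoefy  # assumed import path for the helper patched here

def moefy(model, hidden_size=1024, top_k=2, tp_size=1):
    # `model` is assumed to be a Megatron GPT model exposing
    # model.language_model.transformer.layers, which fmoefy walks.
    hhs = hidden_size * 4 // top_k // tp_size
    return fmoefy(
        model,
        num_experts=4,            # illustrative value
        hidden_hidden_size=hhs,   # must now be supplied explicitly
        top_k=top_k,
        gate=NaiveGate,           # new keyword: inject the gate class directly
    )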
@@ -46,17 +46,23 @@ def patch_forward_step(forward_step_func):
     return forward_step_with_balance_loss
-def patch_model_provider(model_provider):
+def patch_model_provider(model_provider, gate=None):
     from megatron import get_args
     def fmoefied_model_provider():
         from .layers import fmoefy
         args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(),
             num_experts=args.num_experts,
-            hidden_hidden_size=4 * args.hidden_size // args.top_k,
+            hidden_hidden_size=hhs,
             top_k=args.top_k,
+            gate=gate
         )
     return fmoefied_model_provider
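
And a sketch of how the patched provider could be wired into a Megatron pretraining script with the new gate argument. The surrounding names (model_provider, the SwitchGate choice) are illustrative; fmoe.megatron is the assumed import path:

from fmoe.gates import SwitchGate
from fmoe.megatron import patch_model_provider  # assumed import path

def model_provider():
    ...  # build and return the Megatron GPT model as usual

provider = patch_model_provider(model_provider, gate=SwitchGate)
# Calling provider() now runs fmoefy(...) with
# hhs = hidden_size * 4 // top_k // tensor_model_parallel_size
# and the injected SwitchGate, replacing the old 4 * hidden_size // top_k value.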