Commit ae10e942 authored by Rick Ho

update calculation for megatron hhs

parent 07a5d8ac
@@ -74,11 +74,7 @@ class MegatronMLP(FMoETransformerMLP):
     communication group `group` to replace the original MLP layer in Megatron.
     """
-    def __init__(self, args, layer_idx):
-        assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
-            == 0
-        ), "Batch size x sequence length should be multiple of mp size"
+    def __init__(self, args, layer_idx, gate=None):
         if not args.distributed_experts:
             world_size = 1
             moe_group = None
@@ -87,7 +83,6 @@ class MegatronMLP(FMoETransformerMLP):
             from megatron.mpu import get_data_parallel_group
             moe_group = get_data_parallel_group()
-        gate = None
         if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
@@ -100,7 +95,7 @@ class MegatronMLP(FMoETransformerMLP):
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
-        else:
+        elif gate is None:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
         super().__init__(
@@ -152,6 +147,7 @@ def fmoefy(
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
+    gate=None,
 ):
     r"""
     Replace MLP layers in a transformer-based model in Megatron by MoE.
@@ -186,13 +182,10 @@ def fmoefy(
     elif not hasattr(args, "top_k"):
         args.top_k = 2
-    if hidden_hidden_size is not None:
-        args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, "hidden_hidden_size"):
-        args.hidden_hidden_size = args.hidden_size * 4 // args.tensor_model_parallel_size
+    args.hidden_hidden_size = hidden_hidden_size
     for idx, l in enumerate(model.language_model.transformer.layers):
-        l.mlp = MegatronMLP(args, idx)
+        l.mlp = MegatronMLP(args, idx, gate=gate)
     # initialize gate hook
     num_layers = len(model.language_model.transformer.layers)
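Note on the first file's changes: `MegatronMLP.__init__` and `fmoefy` now accept a caller-supplied `gate`, and the final branch of the strategy check became `elif gate is None:`. The self-contained sketch below mirrors that selection logic with strings in place of the real classes from `fmoe.gates`, so it runs without fastmoe installed; branches for strategies not visible in the hunks above are omitted, and `pick_gate` itself is illustrative, not part of the fastmoe API.

# Sketch of the gate selection in MegatronMLP.__init__ after this commit.
# Strings stand in for the real gate classes from fmoe.gates.
def pick_gate(balance_strategy, gate=None):
    if not balance_strategy or balance_strategy == "naive":
        gate = "NaiveGate"
    elif balance_strategy == "switch":
        gate = "SwitchGate"
    elif gate is None:
        # Unknown strategy and no caller-supplied gate: still an error.
        raise ValueError("Undefined balance strategy %s" % balance_strategy)
    return gate

print(pick_gate(None))                     # NaiveGate: built-in strategies still win
print(pick_gate("custom", gate="MyGate"))  # MyGate: a passed gate is used for unrecognized strategies

In other words, as far as the hunks shown here go, passing `gate=` through `fmoefy` only takes effect when `args.balance_strategy` does not resolve to one of the built-in gates.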
@@ -46,17 +46,23 @@ def patch_forward_step(forward_step_func):
     return forward_step_with_balance_loss
-def patch_model_provider(model_provider):
+def patch_model_provider(model_provider, gate=None):
     from megatron import get_args
     def fmoefied_model_provider():
         from .layers import fmoefy
         args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(),
             num_experts=args.num_experts,
-            hidden_hidden_size=4 * args.hidden_size // args.top_k,
+            hidden_hidden_size=hhs,
             top_k=args.top_k,
+            gate=gate
         )
     return fmoefied_model_provider
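The second file's change recomputes `hidden_hidden_size` ("hhs") for Megatron with explicit divisibility checks instead of the old single floor division. A quick stand-alone check of the arithmetic, with illustrative numbers (hidden_size 1024, top_k 2, tensor-parallel size 2) rather than values taken from the repository:

# Stand-alone check of the hhs arithmetic introduced in patch_model_provider.
# Example numbers are illustrative; a real config must keep both divisions exact.
hidden_size = 1024
top_k = 2
tensor_model_parallel_size = 2

hhs = hidden_size * 4                          # Megatron FFN width is 4 * hidden_size
assert hhs % top_k == 0                        # split across the top-k routed experts
hhs //= top_k
assert hhs % tensor_model_parallel_size == 0   # then shard across tensor-parallel ranks
hhs //= tensor_model_parallel_size

print(hhs)  # 1024 = 4 * 1024 // 2 // 2

Compared with the previous inline `4 * args.hidden_size // args.top_k`, the new code fails fast on a configuration where the expert hidden size would not divide evenly instead of silently flooring, and it additionally divides by the tensor-parallel size, which the value passed from patch_model_provider previously ignored.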