Unverified commit a6d202a6, authored by Jiezhong Qiu, committed by GitHub

Megatron v3.0.2 Patch (#159)

* Support Megatron v3.0.2

* Keep `num_experts` for lower versions of Megatron
parent d56522bc
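
In short, the patch renames the expert-count argument and adds a v3.0.2 code path. A usage sketch of the result (not part of the commit; `build_model()` and `train()` are placeholders for whatever the surrounding training script already provides):

```python
# Sketch only: build_model() and train() stand in for the caller's own code.
from fmoe.megatron import fmoefy

model = build_model()
# `num_experts` is renamed to `fmoe_num_experts`, which also avoids clashing
# with the `--num-experts` flag that newer Megatron registers itself (see the
# argparse fallback at the bottom of this diff).
model = fmoefy(
    model,
    fmoe_num_experts=4,          # number of experts per worker
    top_k=2,
    megatron_version="v3.0.2",   # v3.0.2 reuses the v2.5 encoder/decoder layout
)
train(model)
```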
@@ -59,7 +59,7 @@ multiple experts.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
......
@@ -50,7 +50,7 @@ turns the Transformer model into an MoE model. It is used as follows.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
......
@@ -25,7 +25,7 @@ transformer language models.
 ```python
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=4)
+model = fmoefy(model, fmoe_num_experts=4)
 ```
 Note that the `fmoefy` function currently only takes a standard Megatron-LM's
......
This diff is collapsed.
@@ -54,6 +54,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils
 
     expert_dp_comm = "none"
@@ -67,8 +68,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
-    if hasattr(model, 'module'):
-        model = model.module
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # Fall back to the old way of unwrapping a model.
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model,]
 
     print_rank_last(
         "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
@@ -76,7 +82,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 
     # Arguments, iteration, and model.
     state_dict = {}
-    state_dict["model"] = model.state_dict_for_save_checkpoint(
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    state_dict["model"] = model[0].state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0)
     )
@@ -215,6 +222,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils
     from megatron.checkpointing import get_checkpoint_tracker_filename
     from megatron.checkpointing import set_checkpoint_version
     from megatron.checkpointing import check_checkpoint_args
@@ -229,8 +237,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     args = get_args()
     load_dir = getattr(args, load_arg)
 
-    if hasattr(model, 'module'):
-        model = model.module
+    # Unwrap the model before loading its state.
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # Fall back to the old way of unwrapping a model.
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model,]
 
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -341,7 +356,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
        print_rank_last("could not find arguments in the checkpoint ...")
 
    # Model.
-    model.load_state_dict(state_dict["model"])
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    model[0].load_state_dict(state_dict["model"])
 
    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
@@ -350,9 +366,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
            optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
-        except KeyError:
+        except KeyError as e:
            print_rank_last(
-                "Unable to load optimizer from checkpoint {}. "
+                "FMoE is unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
@@ -367,9 +383,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
-        except KeyError:
+        except KeyError as e:
+            print_rank_last(e)
            print_rank_last(
-                "Unable to load optimizer from checkpoint {}. "
+                "FMoE is unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
......
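
The checkpoint changes above revolve around one compatibility shim. A condensed sketch of it, assuming only what the diff implies: newer Megatron exposes `megatron.utils.unwrap_model` and passes the model around as a list of chunks, while older versions do neither, so the fallback normalises to a one-element list that the later `model[0]` accesses expect.

```python
# Sketch of the compatibility shim, not FastMoE's literal helper.
def unwrap_for_checkpoint(model, utils):
    """Return a list of unwrapped model chunks on any supported Megatron."""
    try:
        # Megatron v3.x: `model` is already a list of chunks and
        # utils.unwrap_model keeps it a list.
        return utils.unwrap_model(model)
    except AttributeError:
        # Older Megatron has no utils.unwrap_model: strip the DDP wrapper by
        # hand and wrap the result in a list so callers can index model[0].
        if hasattr(model, "module"):
            model = model.module
        return [model]
```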
@@ -102,7 +102,7 @@ class MegatronMLP(FMoETransformerMLP):
            assert False, "Undefined balance strategy {}" % (args.balance_strategy)
 
        super().__init__(
-            args.num_experts,
+            args.fmoe_num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
@@ -110,7 +110,7 @@ class MegatronMLP(FMoETransformerMLP):
            moe_group=moe_group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
            gate_hook=generate_megatron_gate_hook(
-                layer_idx, args.num_experts * world_size
+                layer_idx, args.fmoe_num_experts * world_size
            ),
            gate=gate,
        )
@@ -157,7 +157,7 @@ class MegatronMLP(FMoETransformerMLP):
 
 def fmoefy(
    model,
-    num_experts=None,
+    fmoe_num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
@@ -183,11 +183,11 @@ def fmoefy(
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts
 
-    if num_experts is not None:
-        args.num_experts = num_experts
+    if fmoe_num_experts is not None:
+        args.fmoe_num_experts = fmoe_num_experts
    assert (
-        "num_experts" in args
-    ), "num_experts should be specified in arguments or fmoefy function"
+        "fmoe_num_experts" in args
+    ), "fmoe_num_experts should be specified in arguments or fmoefy function"
 
    if top_k is not None:
        args.top_k = top_k
@@ -203,19 +203,20 @@ def fmoefy(
 
        # initialize gate hook
        num_layers = len(model.language_model.transformer.layers)
-    elif megatron_version == "v2.5":
+    elif megatron_version in ["v2.5", "v3.0.2"]:
        for idx, l in enumerate(model.language_model.encoder.layers):
            l.mlp = MegatronMLP(args, idx, gate=gate)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            for idx, l in enumerate(model.language_model.decoder.layers):
                l.mlp = MegatronMLP(args, idx, gate=gate)
 
        # initialize gate hook
        num_layers = len(model.language_model.encoder.layers)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            num_layers += len(model.language_model.decoder.layers)
    else:
+        print(model.language_model)
        assert False, f"megatron_version {megatron_version} not known."
 
    reset_gate_hook(num_layers)
......
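
One detail worth noting in the hunk above: the `"fmoe_num_experts" in args` assertion works because `argparse.Namespace` implements `__contains__`, so membership tests look up attribute names. A tiny standalone check:

```python
import argparse

# Namespace membership looks up attribute names, which is what the
# `"fmoe_num_experts" in args` assertion in fmoefy relies on.
args = argparse.Namespace(fmoe_num_experts=4)
print("fmoe_num_experts" in args)  # True
print("num_experts" in args)       # False
```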
@@ -30,7 +30,7 @@ def patch_loss_func_v2_5(loss_func):
                     for l in model.language_model.encoder.layers
                     if l.mlp.gate.has_loss]
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            loss_list_decoder = [l.mlp.gate.get_loss(clear=False).view(1)
                                 for l in model.language_model.decoder.layers
                                 if l.mlp.gate.has_loss]
@@ -125,6 +125,8 @@ def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
        return forward_step_with_balance_loss_v2_2
    elif Megatron_Version == "v2.5":
        return forward_step_with_balance_loss_v2_5
+    elif Megatron_Version == "v3.0.2":
+        return forward_step_with_balance_loss_v2_5
    else:
        assert False, f"megatron version {Megatron_Version} not known."
@@ -143,7 +145,7 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
@@ -160,16 +162,35 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
        hhs = hhs // args.tensor_model_parallel_size
        return fmoefy(
            model_provider(pre_process=pre_process, post_process=post_process),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
            hidden_hidden_size=hhs,
            top_k=args.top_k,
            gate=gate,
            megatron_version="v2.5"
        )
 
+    def fmoefied_model_provider_v3_0_2(pre_process, post_process):
+        from .layers import fmoefy
+        args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
+        return fmoefy(
+            model_provider(pre_process=pre_process, post_process=post_process),
+            fmoe_num_experts=args.fmoe_num_experts,
+            hidden_hidden_size=hhs,
+            top_k=args.top_k,
+            gate=gate,
+            megatron_version="v3.0.2"
+        )
+
    if Megatron_Version == 'v2.2':
        return fmoefied_model_provider_v2_2
    elif Megatron_Version == 'v2.5':
        return fmoefied_model_provider_v2_5
+    elif Megatron_Version == 'v3.0.2':
+        return fmoefied_model_provider_v3_0_2
    else:
        assert False, f"Megatron Version {Megatron_Version} unknown."
r""" r"""
Utility in Megatron Utility in Megatron
""" """
import argparse
def add_fmoe_args(parser): def add_fmoe_args(parser):
group = parser.add_argument_group(title="fastmoe") group = parser.add_argument_group(title="fastmoe")
group.add_argument("--fmoefy", action="store_true") group.add_argument("--fmoefy", action="store_true")
group.add_argument("--num-experts", type=int, default=None) try:
group.add_argument("--num-experts", type=int, default=None)
except argparse.ArgumentError:
group.add_argument("--fmoe-num-experts", type=int, default=None)
group.add_argument("--top-k", type=int, default=2) group.add_argument("--top-k", type=int, default=2)
group.add_argument("--balance-loss-weight", type=float, default=1) group.add_argument("--balance-loss-weight", type=float, default=1)
group.add_argument("--balance-strategy", type=str, default=None) group.add_argument("--balance-strategy", type=str, default=None)
......
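
The `try`/`except argparse.ArgumentError` above is the "keep `num_experts` for lower versions of Megatron" part of the commit message: when `--num-experts` is already registered (as newer Megatron does for its own experts), re-adding it raises `ArgumentError`, and FastMoE registers `--fmoe-num-experts` instead. A standalone sketch of that mechanism, with the first `add_argument` standing in for Megatron's own flag:

```python
import argparse

parser = argparse.ArgumentParser()
# Stand-in for a Megatron version that already defines --num-experts.
parser.add_argument("--num-experts", type=int, default=None)

group = parser.add_argument_group(title="fastmoe")
try:
    # On older Megatron this succeeds and --num-experts belongs to FastMoE.
    group.add_argument("--num-experts", type=int, default=None)
except argparse.ArgumentError:
    # The name is already taken, so fall back to a distinct flag.
    group.add_argument("--fmoe-num-experts", type=int, default=None)

print(parser.parse_args(["--fmoe-num-experts", "8"]))
# Namespace(fmoe_num_experts=8, num_experts=None)
```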