diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@ import os
 import torch
 from megatron import fused_kernels
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -40,6 +42,7 @@
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
+    parser = _add_fmoe_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                                                                   param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
+                        if hasattr(param, 'dp_comm'):
+                            main_param.dp_comm = param.dp_comm
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
                         fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@ from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
 from megatron.model import FP16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log
 
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
 
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()
 
+    # Initialize FastMoE
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func)
+        model_provider = patch_model_provider(model_provider)
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,
 
 
 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter):
+                 loss_scale, report_memory_flag, skipped_iter, model):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
@@ -725,6 +736,8 @@
                          args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
+        if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+            add_balance_log(model, writer, iteration)
 
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag, skipped_iter)
+                                          report_memory_flag, skipped_iter, model)
 
         # Autoresume
         if args.adlr_autoresume and \
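
One of these hunks is easy to miss but load-bearing. Megatron's FP16 optimizer clones every half-precision parameter into a fresh fp32 master copy, so any extra attributes on the original tensor vanish; the optimizer.py change carries FastMoE's per-parameter `dp_comm` tag over to the master copy, so the replacement `DistributedDataParallel` imported in training.py can still decide, per parameter, over which process group (if any) to all-reduce gradients. Expert weights deliberately differ across ranks, so averaging them over the data-parallel group, as stock Megatron DDP would, would destroy the model. As a rough illustration of the mechanism, here is a minimal sketch, not fmoe's actual implementation; the tag values and the `groups` mapping are assumptions:

    import torch.distributed as dist

    def allreduce_gradients_by_tag(model, groups):
        """Reduce each gradient over the process group named by its
        `dp_comm` tag.

        `groups` is assumed to map tag names (e.g. 'dp') to initialized
        torch.distributed process groups; untagged parameters fall back
        to ordinary data parallelism.
        """
        for param in model.parameters():
            if param.grad is None:
                continue
            tag = getattr(param, 'dp_comm', 'dp')
            if tag == 'none':
                # E.g. expert weights: one owner per rank, never averaged.
                continue
            group = groups[tag]
            param.grad.div_(dist.get_world_size(group=group))
            dist.all_reduce(param.grad, group=group)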
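The training.py hunks wire FastMoE through the training loop: when --fmoefy is set, pretrain() wraps the user's forward_step_func and model_provider via patch_forward_step / patch_model_provider, and training_log() gains a model argument so add_balance_log() can write gate-balance metrics to TensorBoard for any balance strategy other than 'naive'. Conceptually, the forward-step patch only has to fold the gates' auxiliary balance loss into the language-model loss. A sketch of that shape, using invented names (collect_balance_loss, BALANCE_LOSS_WEIGHT, the gate.loss attribute) rather than fmoe's real internals:

    BALANCE_LOSS_WEIGHT = 0.01  # assumed coefficient, not fmoe's


    def collect_balance_loss(model):
        """Sum the auxiliary loss each MoE gate stashed during the
        forward pass (hypothetical `gate.loss` attribute)."""
        total = 0.0
        for module in model.modules():
            gate = getattr(module, 'gate', None)
            if gate is not None and getattr(gate, 'loss', None) is not None:
                total = total + gate.loss
        return total


    def patch_forward_step(forward_step_func):
        """Wrap Megatron's forward step so the balance loss joins the
        language-model loss."""
        def patched_forward_step(data_iterator, model):
            loss, loss_reduced = forward_step_func(data_iterator, model)
            loss = loss + BALANCE_LOSS_WEIGHT * collect_balance_loss(model)
            return loss, loss_reduced
        return patched_forward_step

With the patch applied, enabling MoE is then a matter of adding --fmoefy, plus the flags registered by add_fmoe_args (such as --balance-strategy, which the logging hook above consults), to an otherwise unchanged Megatron launch command.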