diff --git a/megatron/training.py b/megatron/training.py index 56d1c7c..9c624d2 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -43,7 +43,8 @@ from megatron.optimizer import get_megatron_optimizer from megatron.initialize import initialize_megatron from megatron.initialize import write_args_to_tensorboard from megatron.learning_rates import AnnealingLR -from megatron.model import DistributedDataParallel as LocalDDP +# from megatron.model import DistributedDataParallel as LocalDDP +from fmoe.megatron import DistributedDataParallel as LocalDDP from megatron.model.realm_model import ICTBertModel from megatron.utils import check_adlr_autoresume_termination from megatron.data.data_loaders import build_pretraining_data_loader diff --git a/pretrain_bert.py b/pretrain_bert.py index 48bc6ad..48628ce 100644 --- a/pretrain_bert.py +++ b/pretrain_bert.py @@ -52,6 +52,8 @@ def model_provider(): num_tokentypes=2, add_binary_head=True, parallel_output=True) + from fmoe.megatron import fmoefy + model = fmoefy(model, num_experts=4) return model