diff --git a/megatron/training.py b/megatron/training.py index 96aec98..fe55dbd 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -33,7 +33,7 @@ from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Optimizer from megatron.initialize import initialize_megatron from megatron.learning_rates import AnnealingLR -from megatron.model import DistributedDataParallel as LocalDDP +from fmoe.megatron import DistributedDataParallel as LocalDDP from megatron.model import get_params_for_weight_decay_optimization from megatron.model.realm_model import ICTBertModel from megatron.utils import check_adlr_autoresume_termination diff --git a/pretrain_bert.py b/pretrain_bert.py index b937b36..5841256 100644 --- a/pretrain_bert.py +++ b/pretrain_bert.py @@ -37,6 +37,8 @@ def model_provider(): num_tokentypes=2, add_binary_head=True, parallel_output=True) + from fmoe.megatron import fmoefy + model = fmoefy(model, num_experts=4) return model