fmoefy-v2.1.patch
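
Applies FastMoE to Megatron-LM v2.1: megatron/training.py swaps Megatron's local DistributedDataParallel for FastMoE's MoE-aware DistributedDataParallel, and pretrain_bert.py converts the BERT transformer MLPs into mixture-of-experts layers via fmoefy. As a rough sketch only (not part of the patch; the BertModel import and keyword arguments are assumed from the hunk below), the patched model_provider() in pretrain_bert.py ends up equivalent to:

    from megatron.model import BertModel
    from fmoe.megatron import fmoefy

    def model_provider():
        # Build the standard Megatron-LM BERT model.
        model = BertModel(num_tokentypes=2,
                          add_binary_head=True,
                          parallel_output=True)
        # Replace each transformer-layer MLP with a 4-expert MoE layer.
        model = fmoefy(model, num_experts=4)
        return model
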
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..9c624d2 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -43,7 +43,8 @@ from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 48bc6ad..48628ce 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -52,6 +52,8 @@ def model_provider():
             num_tokentypes=2,
             add_binary_head=True,
             parallel_output=True)
+    from fmoe.megatron import fmoefy
+    model = fmoefy(model, num_experts=4)
 
     return model