moefy.patch 1008 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
diff --git a/megatron/training.py b/megatron/training.py
index 96aec98..fe55dbd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -33,7 +33,7 @@ from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
 from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
diff --git a/pretrain_bert.py b/pretrain_bert.py
index b937b36..5841256 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -37,6 +37,8 @@ def model_provider():
         num_tokentypes=2,
         add_binary_head=True,
         parallel_output=True)
+    from fmoe.megatron import fmoefy
+    model = fmoefy(model, num_experts=4)
 
     return model