Commit 66f7166d authored by Rick Ho

let fmoe.megatron use correct mpu

parent b0990e4b
@@ -99,6 +99,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
     tensor_model_parall_comm x data_parallel_comm, which is not created.
     '''
     from megatron import get_args
+    from megatron import mpu
     args = get_args()
     if num_experts is not None:
         args.num_experts = num_experts
@@ -121,7 +122,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
     args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
-        l.mlp = MegatronMLP(args, None)
+        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
     return model
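
For context, a minimal sketch of where this change takes effect: fmoefy patches every transformer layer's MLP with an FMoE MegatronMLP, which after this commit is handed Megatron's model-parallel group instead of None. The snippet below is illustrative only; build_megatron_gpt_model is a hypothetical placeholder for whatever builds the plain Megatron model in your training script, and the num_experts value is arbitrary.

# Sketch (not part of this commit): calling fmoefy from a Megatron-LM
# model provider. The fmoefy signature matches the diff above.
from fmoe.megatron import fmoefy

def model_provider():
    model = build_megatron_gpt_model()  # hypothetical builder for the vanilla model
    # Replace each layer's MLP with an FMoE MegatronMLP. After this commit,
    # fmoefy passes mpu.get_model_parallel_group() into MegatronMLP, so the
    # expert communication uses the correct communicator instead of None.
    model = fmoefy(model, num_experts=4, distributed_experts=True)
    return model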