Commit ae2c434e authored by Rick Ho's avatar Rick Ho
Browse files

fix pure data parallel

parent 6b8d2f2e
......@@ -5,11 +5,15 @@ def create_moe_mlp(args, model_parallel_rank, group):
assert (
args.seq_length * args.batch_size % args.model_parallel_size == 0
), "Batch size x sequence length should be multiple of mp size"
if args.model_parallel_size == 1:
world_size = 1
else:
world_size = args.world_size
fmoe = FMoETransformerMLP(
args.num_experts,
d_model=args.hidden_size,
d_hidden=args.hidden_size * 4,
world_size=args.world_size,
world_size=world_size,
model_parallel_size=args.model_parallel_size,
model_parallel_rank=model_parallel_rank,
mp_group=group,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment