Commit ae2c434e authored by Rick Ho
Browse files

fix pure data parallel

parent 6b8d2f2e
...@@ -5,11 +5,15 @@ def create_moe_mlp(args, model_parallel_rank, group): ...@@ -5,11 +5,15 @@ def create_moe_mlp(args, model_parallel_rank, group):
assert ( assert (
args.seq_length * args.batch_size % args.model_parallel_size == 0 args.seq_length * args.batch_size % args.model_parallel_size == 0
), "Batch size x sequence length should be multiple of mp size" ), "Batch size x sequence length should be multiple of mp size"
if args.model_parallel_size == 1:
world_size = 1
else:
world_size = args.world_size
fmoe = FMoETransformerMLP( fmoe = FMoETransformerMLP(
args.num_experts, args.num_experts,
d_model=args.hidden_size, d_model=args.hidden_size,
d_hidden=args.hidden_size * 4, d_hidden=args.hidden_size * 4,
world_size=args.world_size, world_size=world_size,
model_parallel_size=args.model_parallel_size, model_parallel_size=args.model_parallel_size,
model_parallel_rank=model_parallel_rank, model_parallel_rank=model_parallel_rank,
mp_group=group, mp_group=group,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment