megatron.py
from .layers import FMoETransformerMLP


def create_moe_mlp(args, model_parallel_rank, group):
    """Build an FMoETransformerMLP whose experts are split evenly across
    model-parallel ranks."""
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Batch size x seq length should be a multiple of mp size"
    assert (
        args.num_experts % args.model_parallel_size == 0
    ), "Num experts should be multiple of mp size"
    # Each model-parallel rank hosts an equal slice of the experts.
    num_experts = args.num_experts // args.model_parallel_size
    fmoe = FMoETransformerMLP(
        num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=args.world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        group=group,
    )
    return fmoe
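

A minimal usage sketch, not part of the repository file: it assumes the module is importable as fmoe.megatron, that torch.distributed has already been initialized (as Megatron-LM's launcher does), and that model-parallel ranks are contiguous so the rank can be derived with a modulo; only the field names of the args namespace are taken from create_moe_mlp above, the values are illustrative.

from argparse import Namespace

import torch.distributed as dist

from fmoe.megatron import create_moe_mlp

# Illustrative hyper-parameters; only the field names come from create_moe_mlp.
args = Namespace(
    seq_length=1024,
    batch_size=8,
    hidden_size=1024,
    num_experts=8,          # divisible by model_parallel_size
    model_parallel_size=2,
    world_size=2,
)

# Assumes torch.distributed is already initialized elsewhere.
moe_mlp = create_moe_mlp(
    args,
    model_parallel_rank=dist.get_rank() % args.model_parallel_size,
    group=dist.group.WORLD,
)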