# megatron.py
from .layers import FMoETransformerMLP


def create_moe_mlp(args, model_parallel_rank, group):
    """Build an ``FMoETransformerMLP`` configured from ``args``.

    The layer is constructed with
    ``args.num_experts // args.model_parallel_size`` as its first positional
    argument, ``d_model`` set to ``args.hidden_size`` and ``d_hidden`` set to
    four times ``args.hidden_size``.

    Args:
        args: Object exposing the ``num_experts``, ``model_parallel_size``,
            ``hidden_size`` and ``world_size`` attributes.
        model_parallel_rank: Forwarded unchanged to ``FMoETransformerMLP``.
        group: Forwarded unchanged to ``FMoETransformerMLP``.

    Returns:
        The newly constructed ``FMoETransformerMLP`` instance.

    Raises:
        AssertionError: If ``args.num_experts`` is not a multiple of
            ``args.model_parallel_size``.
    """
    # Raised explicitly instead of via an ``assert`` statement so the check
    # still runs under ``python -O``; the exception type and message are kept
    # so existing callers observe the same failure.
    if args.num_experts % args.model_parallel_size != 0:
        raise AssertionError("Num experts should be multiple of mp size")

    num_experts = args.num_experts // args.model_parallel_size
    return FMoETransformerMLP(
        num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=args.world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        group=group,
    )