megatron.py 497 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
from .layers import FMoETransformerMLP
Rick Ho's avatar
Rick Ho committed
2
3
4


def create_moe_mlp(args):
    """Build an ``FMoETransformerMLP`` configured from ``args``.

    The total expert count is divided evenly by the model-parallel size and
    the quotient is passed to ``FMoETransformerMLP`` as its first positional
    argument (presumably the number of experts hosted per rank -- confirm
    against ``FMoETransformerMLP``). The hidden width of the MLP is fixed at
    four times ``args.hidden_size``.

    Args:
        args: Object exposing the attributes ``num_experts``,
            ``model_parallel_size``, ``model_parallel_rank`` and
            ``hidden_size``.

    Returns:
        The constructed ``FMoETransformerMLP`` instance.

    Raises:
        AssertionError: If ``args.num_experts`` is not a multiple of
            ``args.model_parallel_size``.
    """
    total_experts = args.num_experts
    mp_size = args.model_parallel_size

    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # under ``python -O``; otherwise the floor division below would silently
    # drop the remainder experts. Exception type and message are unchanged.
    if total_experts % mp_size != 0:
        raise AssertionError("Num experts should be multiple of mp size")

    return FMoETransformerMLP(
        total_experts // mp_size,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=mp_size,
        model_parallel_rank=args.model_parallel_rank,
    )