megatron.py 432 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
from .layers import FMoETransformerMLP
Rick Ho's avatar
Rick Ho committed
2
3
4
5
6


def create_moe_mlp(args, d_hidden=None):
    """Build a FastMoE transformer MLP layer from a config namespace.

    The global expert count ``args.num_experts`` is split evenly across
    ``args.model_parallel_size`` ranks. The per-rank share is passed to
    ``FMoETransformerMLP`` together with
    ``world_size=args.model_parallel_size``.

    Args:
        args: Config object exposing ``num_experts``,
            ``model_parallel_size`` and ``hidden_size`` attributes
            (presumably the Megatron-LM argument namespace -- confirm
            against callers).
        d_hidden: Inner width of each expert's MLP. Defaults to
            ``4 * args.hidden_size``, the value used before this
            parameter existed.

    Returns:
        The constructed ``FMoETransformerMLP`` instance.

    Raises:
        ValueError: If ``args.model_parallel_size`` is less than 1, or if
            ``args.num_experts`` is not a multiple of it.
    """
    mp_size = args.model_parallel_size
    # Explicit checks rather than ``assert`` so validation is not stripped
    # when Python runs with ``-O``.
    if mp_size < 1:
        raise ValueError(
            f"model_parallel_size must be >= 1, got {mp_size}")
    if args.num_experts % mp_size != 0:
        raise ValueError(
            "Num experts should be multiple of mp size "
            f"(num_experts={args.num_experts}, "
            f"model_parallel_size={mp_size})")
    # Per-rank share of the global experts; presumably the number of experts
    # hosted locally by each FMoETransformerMLP -- verify against its signature.
    num_experts = args.num_experts // mp_size
    if d_hidden is None:
        d_hidden = args.hidden_size * 4
    return FMoETransformerMLP(
        num_experts,
        d_model=args.hidden_size,
        d_hidden=d_hidden,
        world_size=mp_size,
    )