# megatron.py
# Author (per commit history): Rick Ho
from torch import nn
from .moe import FFFN


def create_moe_mlp(args):
    """Construct an ``FFFN`` mixture-of-experts MLP from ``args``.

    Args:
        args: Configuration namespace. Only ``num_experts``,
            ``model_parallel_size`` and ``hidden_size`` are read here.
            (Presumably Megatron's parsed command-line arguments — confirm
            with callers.)

    Returns:
        An ``FFFN`` built with ``num_experts // model_parallel_size``
        experts, ``d_model=hidden_size``, ``d_hidden=4 * hidden_size`` and
        ``world_size=model_parallel_size``.

    Raises:
        ValueError: if ``model_parallel_size`` is not positive, or if
            ``num_experts`` is not a multiple of ``model_parallel_size``.
    """
    mp_size = args.model_parallel_size
    # Validate with explicit raises instead of ``assert`` so the checks still
    # run when Python is started with ``-O`` (which strips asserts).
    if mp_size < 1:
        raise ValueError(
            f'model_parallel_size must be positive, got {mp_size}')
    if args.num_experts % mp_size != 0:
        raise ValueError(
            'Num experts should be multiple of mp size '
            f'(num_experts={args.num_experts}, model_parallel_size={mp_size})')
    # The total expert count is divided evenly by the model-parallel size;
    # FFFN receives the quotient (presumably the per-rank expert count —
    # TODO confirm against FFFN's constructor).
    num_experts = args.num_experts // mp_size

    fmoe = FFFN(num_experts,
                d_model=args.hidden_size,
                d_hidden=args.hidden_size * 4,
                world_size=mp_size)
    return fmoe