megatron.py
from .layers import FMoETransformerMLP


def create_moe_mlp(args, model_parallel_rank, group):
    """Build an FMoETransformerMLP configured from Megatron-LM arguments."""
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Batch size x sequence length should be a multiple of model parallel size"
    # Keep the MoE layer local when model parallelism is disabled;
    # otherwise distribute the experts across all workers.
    if args.model_parallel_size == 1:
        world_size = 1
    else:
        world_size = args.world_size
    fmoe = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        # Megatron's MLP uses a 4x hidden expansion in its feed-forward layer.
        d_hidden=args.hidden_size * 4,
        world_size=world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        mp_group=group,
    )
    return fmoe


def fmoefy(model, num_experts=None):
    """Replace the MLP in every transformer layer of a Megatron-LM model
    with a mixture-of-experts MLP, and return the patched model."""
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts should be specified in the arguments or passed to fmoefy'
    for layer in model.language_model.transformer.layers:
        layer.mlp = create_moe_mlp(
            args,
            mpu.get_model_parallel_rank(),
            mpu.get_model_parallel_group(),
        )
    return model
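
# Usage sketch (an assumption-laden example, not part of this module): assuming this
# file lives in an `fmoe` package, as the relative import above suggests, and that
# `model_provider()` is the usual helper in a Megatron-LM pretraining script that
# builds the base GPT/BERT model, the model could be patched like this:
#
#     from fmoe.megatron import fmoefy
#
#     model = model_provider()              # plain Megatron-LM transformer
#     model = fmoefy(model, num_experts=4)  # each layer's MLP becomes an MoE MLP
#
# fmoefy modifies the layers in place and also returns the model, so either the
# returned value or the original reference can be used afterwards.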