megatron.py
from .layers import FMoETransformerMLP


def create_moe_mlp(args, model_parallel_rank, group):
    """Build an FMoETransformerMLP configured from Megatron-LM arguments."""
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Batch size x sequence length should be multiple of mp size"
    fmoe = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=args.world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        mp_group=group,
    )
    return fmoe


def fmoefy(model, num_experts=None):
    """Replace the MLP in every transformer layer of a Megatron-LM model with an MoE MLP."""
    from megatron import get_args
    from megatron import mpu
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts should be specified in the Megatron arguments or passed to fmoefy'
    # Swap the dense MLP of every transformer layer for an MoE MLP.
    for layer in model.language_model.transformer.layers:
        layer.mlp = create_moe_mlp(args,
                                   mpu.get_model_parallel_rank(),
                                   mpu.get_model_parallel_group())
    return model
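

# Usage sketch (illustrative only, not part of this module): `fmoefy` is meant
# to be applied to a freshly constructed Megatron-LM language model before
# training starts, e.g. inside the model-provider function of a pretraining
# script. `build_megatron_model` below is a hypothetical helper standing in for
# whatever builds the GPT model in your setup, and 4 experts is an arbitrary
# example value:
#
#     def model_provider():
#         model = build_megatron_model()  # hypothetical: returns a Megatron-LM GPT model
#         return fmoefy(model, num_experts=4)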