from .layers import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


def create_moe_mlp(args, model_parallel_rank, group):
    r"""Build an FMoETransformerMLP that replaces Megatron-LM's MLP block.

    If args.distributed_experts is set, expert parallelism spans all
    args.world_size workers; otherwise the MoE layer is local to each worker.
    """
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Batch size x sequence length should be multiple of mp size"
    # Expert parallelism spans all workers only in distributed expert mode
    if not args.distributed_experts:
        world_size = 1
    else:
        world_size = args.world_size
    fmoe = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        mp_group=group,
    )
    # The gate is replicated on every worker, so mark its parameters as shared
    for p in fmoe.gate.parameters():
        setattr(p, 'shared', True)
    return fmoe
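
# A minimal sketch of the fields that create_moe_mlp reads from `args`,
# collected here in an argparse-style Namespace; the concrete values are
# illustrative assumptions, not part of this module:
#
#     from argparse import Namespace
#     args = Namespace(
#         seq_length=1024, batch_size=8,
#         model_parallel_size=2, world_size=8,
#         hidden_size=1024, num_experts=4,
#         distributed_experts=True,
#     )
#     mlp = create_moe_mlp(args,
#                          mpu.get_model_parallel_rank(),
#                          mpu.get_model_parallel_group())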


def fmoefy(model, num_experts=None, distributed_experts=True):
    r"""Replace the MLP in every transformer layer of a Megatron-LM model
    with a mixture-of-experts MLP built by create_moe_mlp.
    """
    from megatron import get_args
    from megatron import mpu
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts must be specified either in the Megatron arguments or as an argument to fmoefy'

    # Pass distributed_experts=None to keep the default setting from args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for layer in model.language_model.transformer.layers:
        layer.mlp = create_moe_mlp(
            args,
            mpu.get_model_parallel_rank(),
            mpu.get_model_parallel_group(),
        )
    return model
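
# A minimal usage sketch, assuming a Megatron-LM training script whose model
# provider builds a standard GPT model; the GPT2Model call and the rest of the
# Megatron setup are assumptions for the example, not part of this module:
#
#     def model_provider():
#         model = GPT2Model(num_tokentypes=0, parallel_output=True)
#         # Replace every transformer MLP with a 4-expert MoE layer
#         return fmoefy(model, num_experts=4)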


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""Wrap a module with fastmoe's DistributedGroupedDataParallel, using
    Megatron-LM's model parallel and data parallel communication groups, so
    it can be used in place of Megatron-LM's own DistributedDataParallel.
    """

    def __init__(self, module):
        from megatron import mpu
        super(DistributedDataParallel, self).__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group()
        )
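
# A minimal sketch of wrapping the fmoefied model for training, assuming this
# module is importable as fmoe.megatron; the wrapper is applied where a
# Megatron-LM training script would otherwise use its own
# DistributedDataParallel:
#
#     from fmoe.megatron import DistributedDataParallel as LocalDDP
#     model = LocalDDP(model)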