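"""Glue code for using FMoE inside Megatron-LM.

`fmoefy` swaps the MLP of every transformer layer for an `FMoETransformerMLP`,
and `DistributedDataParallel` wraps the resulting model in FMoE's group-aware
DDP using Megatron's model- and data-parallel groups.
"""
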
from .layers import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


def create_moe_mlp(args, group):
    """Build an FMoETransformerMLP from Megatron-style `args`, using `group`
    as the model-parallel communication group."""
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Batch size x sequence length should be multiple of mp size"
    # Experts are spread across ranks only when distributed_experts is set;
    # otherwise the MoE layer works purely locally (world size 1).
    if not args.distributed_experts:
        world_size = 1
    else:
        world_size = args.world_size
    fmoe = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=world_size,
        mp_group=group
    )
    # The gate is replicated rather than partitioned across ranks, so mark its
    # parameters as shared for the surrounding distributed training code.
    for p in fmoe.gate.parameters():
        setattr(p, 'shared', True)
    return fmoe


def fmoefy(model, num_experts=None, distributed_experts=True):
    """Replace the MLP in every transformer layer of a Megatron-LM `model`
    with an MoE MLP.

    `num_experts` overrides `args.num_experts` when given; pass
    `distributed_experts=None` to keep the value already present in `args`.
    """
    from megatron import get_args
    from megatron import mpu
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts must be given in the Megatron arguments or passed to fmoefy'

    # Passing distributed_experts=None keeps whatever is already set in args;
    # any other value overrides it.
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    # Swap the dense MLP in every transformer layer for an MoE MLP.
    for layer in model.language_model.transformer.layers:
        layer.mlp = create_moe_mlp(args, mpu.get_model_parallel_group())
    return model
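

# Usage sketch (not from this file): assuming this module is importable as
# `fmoe.megatron` and `model_provider` builds a standard Megatron-LM model,
# one might do:
#
#     from fmoe.megatron import fmoefy
#     model = model_provider()
#     model = fmoefy(model, num_experts=4)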


class DistributedDataParallel(DistributedGroupedDataParallel):
    """Wrap `module` in FMoE's DistributedGroupedDataParallel, bound to
    Megatron's model-parallel and data-parallel groups."""

    def __init__(self, module):
        from megatron import mpu
        super(DistributedDataParallel, self).__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group()
        )
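

# Usage sketch (not from this file): after `fmoefy`, the model can be wrapped
# with the class above in place of Megatron-LM's own DistributedDataParallel,
# so gradients are reduced over the matching model- and data-parallel groups:
#
#     from fmoe.megatron import DistributedDataParallel as FMoEDDP
#     model = FMoEDDP(fmoefy(model, num_experts=4))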