megatron.py
'''
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
from .layers import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


def _create_moe_mlp(args, group):
    r'''
    Make the FMoETransformerMLP layer that distributes experts across
    communication group `group` to replace the original MLP layer in Megatron.
    '''
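    # Example of the constraint checked below (illustrative numbers, not
    # defaults): with seq_length=1024 and micro_batch_size=2, the 2048 tokens
    # of a micro-batch divide evenly among tensor_model_parallel_size=4 ranks.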
    assert (args.seq_length * args.micro_batch_size
            % args.tensor_model_parallel_size == 0
    ), "Batch size x sequence length should be multiple of mp size"
    if not args.distributed_experts:
        world_size = 1
    else:
        world_size = args.world_size
    fmoe = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=world_size,
        mp_group=group
    )
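    # Presumably marks the gate's parameters as shared (not expert-local) so
    # that downstream gradient reduction can treat them differently from the
    # per-expert weights.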
    for p in fmoe.gate.parameters():
        setattr(p, 'shared', True)
    return fmoe


def fmoefy(model, num_experts=None, distributed_experts=True):
    r'''
    Replace MLP layers in a transformer-based model in Megatron by MoE.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as transformer layers, which is an
    array of transformer blocks that contain an `mlp` member.
    * `distributed_experts` should be set to True if different experts are located on
    different workers. Otherwise, the experts on the workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    '''
    from megatron import get_args
    from megatron import mpu
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts should be specified in arguments or fmoefy function'

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = _create_moe_mlp(args, mpu.get_model_parallel_group())
    return model
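
# A hedged usage sketch (not part of the original file): in Megatron-LM v2.0's
# GPT pretraining script, calling `fmoefy` is the first of the "two lines of
# modification" mentioned above. The `GPTModel` arguments and `num_experts=4`
# below are illustrative assumptions, not prescribed values.
#
#     from fmoe.megatron import fmoefy
#
#     def model_provider():
#         model = GPTModel(num_tokentypes=0, parallel_output=True)
#         # Added line: replace every transformer MLP with a FastMoE MoE layer.
#         model = fmoefy(model, num_experts=4)
#         return model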


class DistributedDataParallel(DistributedGroupedDataParallel):
    r'''
    A wrapper that replaces the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    '''
    def __init__(self, module):
        from megatron import mpu
        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group()
        )

    def state_dict(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron
        '''
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron
        '''
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron
        '''
        return self.module.load_state_dict(*args, **kwargs)
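
# A hedged sketch of the second modification (assuming Megatron's standard
# training loop, where the model is wrapped in a DDP class): use the
# DistributedDataParallel defined above in place of Megatron's own wrapper so
# that gradients are reduced within the proper model- and data-parallel groups.
#
#     from fmoe.megatron import DistributedDataParallel as LocalDDP
#
#     model = LocalDDP(model)  # instead of Megatron's DistributedDataParallel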