r'''
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
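
# A minimal sketch of the intended integration, assuming a standard Megatron-LM
# v2.0 pretraining script and that this package is importable as `fmoe`
# (see `examples/megatron` for the authoritative instructions):
#
#     from fmoe.megatron import fmoefy, DistributedDataParallel
#
#     model = fmoefy(model, num_experts=4)    # replace the MLP layers with MoE
#     model = DistributedDataParallel(model)  # replace Megatron's DDP wrapper
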
import torch

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .utils import get_torch_default_comm


class MegatronMLP(FMoETransformerMLP):
    r'''
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, used to replace the original MLP layer
    in Megatron.
    '''
    def __init__(self, args, group):
        assert (args.seq_length * args.micro_batch_size
                % args.tensor_model_parallel_size == 0
        ), "Batch size x sequence length should be a multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(args.num_experts,
                top_k=args.top_k,
                d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
                world_size=world_size, mp_group=group)
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(args.hidden_size, dtype=torch.float32)
        )

    def forward(self, inp):
        # Megatron expects the MLP to return (output, bias) and applies the
        # bias outside of this block, so mirror that interface here.
        return super().forward(inp), self.bias


def fmoefy(model, num_experts=None, distributed_experts=True,
        hidden_hidden_size=None, top_k=None):
    r'''
    Replace the MLP layers in a transformer-based Megatron model with MoE
    layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, an
    array of transformer blocks that each contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical and
    are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large
    parameter capacity.
    Note that pipeline parallelism is not supported yet. When distributed
    experts are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    '''
    from megatron import get_args
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts should be specified in Megatron arguments or to fmoefy'

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, 'hidden_hidden_size'):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, 'top_k'):
        args.top_k = 2

    # Pass distributed_experts=None to keep the default setting from args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for layer in model.language_model.transformer.layers:
        layer.mlp = MegatronMLP(args, get_torch_default_comm())
    return model
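
# A hedged example of calling `fmoefy` directly (the values are illustrative):
# when testing small models that do not need distributed experts, keep the
# experts identical across workers and train them in data-parallel mode:
#
#     model = fmoefy(model, num_experts=4, top_k=2, distributed_experts=False)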


class DistributedDataParallel(DistributedGroupedDataParallel):
    r'''
    A wrapper used to replace the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    '''
    def __init__(self, module):
        from megatron import mpu
        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group()
        )

    def state_dict(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron by forwarding to the wrapped module.
        '''
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron by forwarding to the wrapped module.
        '''
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r'''
        Keep consistency with Megatron by forwarding to the wrapped module.
        '''
        return self.module.load_state_dict(*args, **kwargs)
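
# A short sketch of using the wrapper in place of Megatron's DDP (variable
# names are illustrative); the methods above simply forward to the wrapped
# module, so existing Megatron checkpoint code keeps working:
#
#     model = DistributedDataParallel(model)
#     state = model.state_dict()
#     model.load_state_dict(state)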