megatron.py 4.45 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
r'''
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .utils import get_torch_default_comm


class _MegatronMLP(nn.Module):
    r'''
    A plain feed-forward block: linear -> GELU -> linear. Its `forward`
    returns a pair `(output, zeros)` rather than a single tensor.
    '''
    def __init__(self, args, group):
        # `group` is accepted for signature compatibility but is not used here.
        super().__init__()
        d_model, d_inner = args.hidden_size, args.hidden_hidden_size
        # Keep fc1 before fc2 so that parameter order stays the same.
        self.fc1 = nn.Linear(d_model, d_inner)
        self.fc2 = nn.Linear(d_inner, d_model)

    def forward(self, x):
        r'''
        Run `x` through `fc1`, GELU and `fc2`. The second returned element is
        an all-zero tensor with the same shape, dtype and device as the output.
        '''
        out = self.fc2(F.gelu(self.fc1(x)))
        return out, torch.zeros_like(out)


class MegatronMLP(FMoETransformerMLP):
    r'''
    An FMoETransformerMLP layer configured from Megatron's `args`, meant to
    take the place of the original MLP layer in Megatron. Experts are
    distributed across the communication group `group`.
    '''
    def __init__(self, args, group):
        r'''
        Read the layer configuration from `args` and forward it to
        FMoETransformerMLP. `group` is passed through as `mp_group`.
        '''
        assert (
            (args.seq_length * args.micro_batch_size)
            % args.tensor_model_parallel_size == 0
        ), "Batch size x sequence length should be multiple of mp size"

        # Distributed experts span `args.world_size` workers and use no
        # data-parallel expert communication; otherwise a single worker is
        # assumed and 'dp' is selected.
        if args.distributed_experts:
            n_workers, expert_comm = args.world_size, 'none'
        else:
            n_workers, expert_comm = 1, 'dp'

        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=n_workers,
            mp_group=group,
            expert_dp_comm=expert_comm,
        )
        self.hidden_size = args.hidden_size

    def forward(self, inp):
        r'''
        Return `(moe_output, zeros)`, where `zeros` is a 1-D zero tensor of
        length `hidden_size` that shares `inp`'s dtype and device.
        NOTE(review): the zero tensor presumably stands in for the bias that
        Megatron's MLP returns -- confirm against the Megatron caller.
        '''
        moe_output = super().forward(inp)
        zeros = torch.zeros(
            self.hidden_size, dtype=inp.dtype, device=inp.device
        )
        return moe_output, zeros


def fmoefy(model, num_experts=None, distributed_experts=True,
        hidden_hidden_size=None, top_k=None):
    r'''
    Swap every MLP layer of a Megatron transformer model for an MoE layer and
    return the same `model` object.

    * `model` should be a standard Megatron model exposing its transformer
    blocks as `model.language_model.transformer.layers`; the `mlp` member of
    each block is overwritten with a new `MegatronMLP`.
    * `num_experts`, `hidden_hidden_size` and `top_k`, when given, are written
    into Megatron's global args. Otherwise the value already present in args
    is kept, with fallbacks of `4 * hidden_size` for `hidden_hidden_size` and
    `2` for `top_k`. `num_experts` has no fallback and must come from one of
    the two places.
    * `distributed_experts` is set to True if different experts are located in
    different workers. Otherwise, the experts on the workers are identical and
    are trained in data-parallel mode, which can be useful when testing small
    models that need neither high training throughput nor a large parameter
    capacity. Pass None to keep whatever value args already holds.

    Note that pipeline parallel is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    '''
    from megatron import get_args

    megatron_args = get_args()

    if num_experts is not None:
        megatron_args.num_experts = num_experts
    assert 'num_experts' in megatron_args, (
        'num_experts should be specified in arguments or fmoefy function'
    )

    # For the two settings below: an explicit argument wins, then an existing
    # args value, then the built-in fallback.
    if hidden_hidden_size is None:
        if not hasattr(megatron_args, 'hidden_hidden_size'):
            megatron_args.hidden_hidden_size = megatron_args.hidden_size * 4
    else:
        megatron_args.hidden_hidden_size = hidden_hidden_size

    if top_k is None:
        if not hasattr(megatron_args, 'top_k'):
            megatron_args.top_k = 2
    else:
        megatron_args.top_k = top_k

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        megatron_args.distributed_experts = distributed_experts

    for layer in model.language_model.transformer.layers:
        layer.mlp = MegatronMLP(megatron_args, None)
    return model


class DistributedDataParallel(DistributedGroupedDataParallel):
    r'''
    A wrapper that is used to replace the DDP module provided by Megatron. It
    is adapted to enable the sophisticated parallel and reduction strategies
    in FastMoE.
    '''
    def __init__(self, module):
        r'''
        Wrap `module`, taking the model-parallel and data-parallel groups
        from Megatron's `mpu`.
        '''
        from megatron import mpu

        model_parallel = mpu.get_model_parallel_group()
        data_parallel = mpu.get_data_parallel_group()
        super().__init__(
            module, mp_group=model_parallel, dp_group=data_parallel
        )

    def state_dict(self, *args, **kwargs):
        r'''
        Delegate to the wrapped module, to keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r'''
        Delegate to the wrapped module, to keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r'''
        Delegate to the wrapped module, to keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.load_state_dict(*args, **kwargs)