"server/requirements_rocm.txt" did not exist on "16fadcec5711ff232977b38c74a1c8829af6a63b"
megatron.py 3.43 KB
Newer Older
1
r'''
Rick Ho's avatar
Rick Ho committed
2
3
4
5
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
Rick Ho's avatar
Rick Ho committed
6
from .layers import FMoETransformerMLP
7
from .distributed import DistributedGroupedDataParallel
8
from .utils import get_torch_default_comm
Rick Ho's avatar
Rick Ho committed
9

Rick Ho's avatar
Rick Ho committed
10

Rick Ho's avatar
Rick Ho committed
11
12
13
14
15
16
17
def _create_moe_mlp(args, group):
    r'''
    Build the FMoETransformerMLP layer that takes the place of Megatron's
    original MLP layer, with its experts distributed across the communication
    group `group`.

    * `args` is the Megatron argument namespace. The fields read here are
    `seq_length`, `micro_batch_size`, `tensor_model_parallel_size`,
    `distributed_experts`, `world_size`, `num_experts` and `hidden_size`.
    * `group` is forwarded to the layer as its `mp_group`.

    Returns the newly constructed layer. An `AssertionError` is raised when
    sequence length x micro batch size is not a multiple of the tensor model
    parallel size.
    '''
    assert (
        (args.seq_length * args.micro_batch_size)
        % args.tensor_model_parallel_size == 0
    ), "Batch size x sequence length should be multiple of mp size"

    # Without distributed experts the layer is built as if it ran on a
    # single worker.
    expert_world_size = args.world_size if args.distributed_experts else 1

    moe_mlp = FMoETransformerMLP(
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=expert_world_size,
        mp_group=group,
    )

    # Tag every gate parameter as `shared`.
    # NOTE(review): presumably consumed by the data-parallel wrapper when it
    # decides how to reduce gradients -- confirm against its implementation.
    for gate_param in moe_mlp.gate.parameters():
        gate_param.shared = True
    return moe_mlp
Rick Ho's avatar
fmoefy  
Rick Ho committed
33
34


35
def fmoefy(model, num_experts=None, distributed_experts=True):
    r'''
    Swap the MLP layer of every transformer block in a Megatron model for an
    MoE layer, and return the (in-place modified) model.

    * `model` should be a standard Megatron model that exposes its transformer
    blocks as `model.language_model.transformer.layers`, each of which has an
    `mlp` member.
    * `num_experts`, when not None, overrides `num_experts` in Megatron's
    arguments. The number of experts has to be given either here or in the
    arguments, otherwise an `AssertionError` is raised.
    * `distributed_experts` is set to True if different experts are located in
    different workers. Otherwise, the experts on the workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large
    parameter capacity. Pass None to keep the setting already present in
    Megatron's arguments.

    Note that pipeline parallel is not supported yet. When distributed experts
    are enabled, their communicator should be the product of Megatron's
    tensor-model-parallel and data-parallel communicators, which is not
    created.
    '''
    from megatron import get_args

    megatron_args = get_args()

    if num_experts is not None:
        megatron_args.num_experts = num_experts
    assert (
        'num_experts' in megatron_args
    ), 'num_experts should be specified in arguments or fmoefy function'

    # None means "leave whatever the arguments already say" untouched.
    if distributed_experts is not None:
        megatron_args.distributed_experts = distributed_experts

    for block in model.language_model.transformer.layers:
        block.mlp = _create_moe_mlp(megatron_args, get_torch_default_comm())
    return model
65
66
67


class DistributedDataParallel(DistributedGroupedDataParallel):
    r'''
    Stand-in for the DDP module shipped with Megatron, built on
    DistributedGroupedDataParallel so that the sophisticated parallel and
    reduction strategies of Fast MoE can be used.

    The state-dict related methods are forwarded to the wrapped module to keep
    consistency with Megatron.
    '''

    def __init__(self, module):
        r'''
        Wrap `module`, using Megatron's model parallel group and data parallel
        group as the `mp_group` and `dp_group` of the parent class.
        '''
        from megatron import mpu

        model_parallel_group = mpu.get_model_parallel_group()
        data_parallel_group = mpu.get_data_parallel_group()
        super().__init__(
            module,
            mp_group=model_parallel_group,
            dp_group=data_parallel_group,
        )

    def state_dict(self, *args, **kwargs):
        r'''
        Return the wrapped module's `state_dict`, forwarding all arguments, to
        keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r'''
        Return the wrapped module's `state_dict_for_save_checkpoint`,
        forwarding all arguments, to keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r'''
        Load a state dict into the wrapped module, forwarding all arguments,
        to keep consistency with Megatron.
        '''
        wrapped = self.module
        return wrapped.load_state_dict(*args, **kwargs)