r"""
distributed support for Megatron
"""
import torch

from fmoe.distributed import DistributedGroupedDataParallel


_groups = None


def _set_groups(**kwargs):
    global _groups
    _groups = kwargs


def get_moe_group():
    return _groups["moe_group"]


def _init():
    from megatron import get_args
    from megatron import mpu
    args = get_args()

    # Create a communicator perpendicular to the pipeline groups to serve as the gate group
    stage_size = args.world_size // args.pipeline_model_parallel_size
    for i in range(0, args.world_size, stage_size):
        ranks = range(i, i + stage_size)
        group = torch.distributed.new_group(ranks)
        if args.rank in ranks:
            gate_group = group
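    # Worked example (a sketch, assuming Megatron's default rank ordering with
    # the pipeline-parallel dimension outermost): with world_size=8 and
    # pipeline_model_parallel_size=2, stage_size is 4, so the gate groups are
    # ranks {0,1,2,3} and {4,5,6,7}, perpendicular to the pipeline groups
    # {0,4}, {1,5}, {2,6} and {3,7}. Every rank falls into exactly one gate
    # group, so `gate_group` is always bound after the loop.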

    _set_groups(
            dp_group=mpu.get_data_parallel_group(),
            moe_group=mpu.get_data_parallel_group(),
            gate_group=gate_group)


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper that replaces the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    """
    
    def __init__(self, module, accumulate_allreduce_grads_in_fp32=False,
                 use_contiguous_buffers_in_ddp=False):
        assert not accumulate_allreduce_grads_in_fp32, \
                "FastMoE does not support accumulate_allreduce_grads_in_fp32"
        assert not use_contiguous_buffers_in_ddp, \
                "FastMoE does not support use_contiguous_buffers_in_ddp"

        # Process groups are created lazily on the first wrap and reused afterwards
        if _groups is None:
            _init()
        super().__init__(module, **_groups)

    def set_input_tensor(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.set_input_tensor(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)
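

# Usage sketch (illustrative; assumes this module is importable as
# `fmoe.megatron` and that `model` is a Megatron module whose MLP layers have
# already been replaced by FastMoE experts):
#
#     from fmoe.megatron import DistributedDataParallel as LocalDDP
#     model = LocalDDP(model)   # process groups are created lazily by _init()
#     state = model.state_dict_for_save_checkpoint()
#
# The dp/moe/gate process groups are created once on the first wrap and shared
# by any later wraps through the module-level `_groups` cache.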