r"""
Supporting modules for conducting distributed training
"""
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .utils import get_torch_default_comm


class DistributedGroupedDataParallel(nn.Module):
    r"""
    A customized DDP module that supports multiple all-reduce regions within
    the model.  An all-reduce region is declared by setting a `dp_comm`
    attribute on a weight tensor, and gradients are grouped and reduced
    separately according to that attribute.
    If `dp_comm` is set to `dp`, the gradient is reduced only across the
    data-parallel group, i.e. it is not synchronized within the model-parallel
    group.
    If it is set to `world`, the gradient is synchronized across all workers,
    regardless of their model- or data-parallel group.  This is particularly
    useful for shared layers such as the gate.
    The `mp` and `moe` tags map to the corresponding groups passed to the
    constructor, and weights without a `dp_comm` attribute default to `dp`.
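
    Example (a minimal usage sketch; the model, process groups, optimizer and
    input batch below are illustrative placeholders, not part of this module)::

        model = DistributedGroupedDataParallel(
            MyMoEModel().cuda(),
            mp_group=mp_group,      # model-parallel group
            dp_group=dp_group,      # data-parallel group
            moe_group=moe_group,    # group for MoE (expert) parameters
            world_group=None,       # defaults to the global communicator
        )

        loss = model(batch).sum()
        loss.backward()
        model.allreduce_params()    # reduce grads per `dp_comm` group
        optimizer.step()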
    """

    def __init__(
        self,
        module,
        mp_group=None,
        dp_group=None,
        moe_group=None,
        world_group=None,
        auto_allreduce=False,
    ):
        assert not auto_allreduce, "Automatic all-reduce is not implemented yet"

        super().__init__()
        self.module = module

        self.comms = dict()
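        # Map each `dp_comm` tag ("mp", "dp", "moe", "world") to the process
        # group used to synchronize it.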
        if mp_group is not None:
            self.comms["mp"] = mp_group
        if dp_group is not None:
            self.comms["dp"] = dp_group
        else:
            self.comms["dp"] = get_torch_default_comm()
        if moe_group is not None:
            self.comms["moe"] = moe_group
        else:
            self.comms["moe"] = get_torch_default_comm()
        if world_group is None:
            self.comms["world"] = get_torch_default_comm()
        else:
            self.comms["world"] = world_group

        def allreduce_params(no_scale=False,
                reduce_after=False, fp32_allreduce=False):
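            # Bucket gradients by (dp_comm tag, dtype), flatten each bucket,
            # all-reduce it over the matching communicator, and copy the
            # reduced values back into the individual grads.  By default the
            # flattened buffer is divided by the group size before the
            # all-reduce; `reduce_after` divides afterwards instead, and
            # `no_scale` skips the division entirely.  `fp32_allreduce`
            # upcasts non-fp32 buckets to fp32 for the reduction.  Exposed as
            # `self.allreduce_params` for callers to invoke after backward().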
            groups = dict()
            for p in self.module.parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                if hasattr(p, "dp_comm"):
                    dp_comm = p.dp_comm
                else:
                    dp_comm = "dp"
                group_key = (dp_comm, p.dtype)
                if group_key not in groups:
                    groups[group_key] = [p]
                else:
                    groups[group_key].append(p)
            for (dp_comm, dtype), group in groups.items():
                if dp_comm not in self.comms:
                    continue
                comm = self.comms[dp_comm]
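                # Flatten the bucket into one contiguous buffer so a single
                # collective call covers every parameter in it.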
                grads = [p.grad.data for p in group]
                coalesced = _flatten_dense_tensors(grads)
                if fp32_allreduce and dtype != torch.float32:
                    coalesced = coalesced.float()
                if not no_scale and not reduce_after:
                    coalesced /= comm.size()
                torch.distributed.all_reduce(coalesced, group=comm)
                torch.cuda.synchronize()
                if not no_scale and reduce_after:
                    coalesced /= comm.size()
                synced = _unflatten_dense_tensors(coalesced, grads)
                for g, s in zip(grads, synced):
                    g.copy_(s)

        self.allreduce_params = allreduce_params
        self._sync_params()

    def _sync_params(self):
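        # Broadcast parameter buckets (from rank 0) over each group's
        # communicator so that the participating replicas hold identical
        # weights; buckets are built the same way as in `allreduce_params`.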
        groups = dict()
        for p in self.module.parameters():
            if not p.requires_grad or p.grad is None:
                continue
            if hasattr(p, "dp_comm"):
                dp_comm = p.dp_comm
            else:
                dp_comm = "dp"
            group_key = (dp_comm, p.dtype)
            if group_key not in groups:
                groups[group_key] = [p]
            else:
                groups[group_key].append(p)
        for (dp_comm, _), group in groups.items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            datas = [p.data for p in group]
            coalesced = _flatten_dense_tensors(datas)
            torch.distributed.broadcast(coalesced, 0, group=comm)
            torch.cuda.synchronize()
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r"""
        Directly call the module's forward function.
        """
        return self.module(*args, **kwargs)