r"""
Supporting modules for distributed training
"""
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .utils import get_torch_default_comm


class DistributedGroupedDataParallel(nn.Module):
    r"""
    A customized DDP module that supports different all-reduce regions in the
    model.  An all-reduce region is defined by the `dp_comm` attribute of the
    weight object.
    The gradients of the weights are reduced in different groups according to
    the weights' `dp_comm` attribute.
    If it is set to `dp`, the gradient is only reduced across the data-parallel
    group, which means that it is not synchronized within the model-parallel
    group.
    If it is set to `world`, the gradient is synchronized across all workers,
    regardless of their model- or data-parallel group. This is extremely useful
    for shared layers like the gate.
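
    A minimal usage sketch (the data-parallel group handle, input batch, and
    optimizer are assumed to be created elsewhere, e.g. via
    `torch.distributed.new_group`)::

        model = DistributedGroupedDataParallel(model, dp_group=dp_group)
        loss = model(batch).sum()
        loss.backward()
        model.allreduce_params()
        optimizer.step()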
    """

    def __init__(
        self,
        module,
        mp_group=None,
        dp_group=None,
        world_group=None,
        auto_allreduce=False,
    ):
        assert not auto_allreduce, "Automatic all-reduce is not implemented yet"

        super().__init__()
        self.module = module

        self.comms = dict()
        if mp_group is not None:
            self.comms["mp"] = mp_group
        if dp_group is not None:
            self.comms["dp"] = dp_group
        else:
            self.comms["dp"] = get_torch_default_comm()
        if world_group is None:
            self.comms["world"] = get_torch_default_comm()
        else:
            self.comms["world"] = world_group

        def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
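            r"""
            All-reduce the gradients of the wrapped module, using one
            coalesced all-reduce per (dp_comm, dtype) group.
            By default the coalesced gradients are divided by the group size
            before the all-reduce; `reduce_after` performs the division after
            the all-reduce instead, `no_scale` skips it entirely, and
            `fp32_allreduce` casts non-fp32 gradients to fp32 for the
            reduction.
            """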
            groups = dict()
            for p in self.module.parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                if hasattr(p, "dp_comm"):
                    dp_comm = p.dp_comm
                else:
                    dp_comm = "dp"
                group_key = (dp_comm, p.dtype)
                if group_key not in groups:
                    groups[group_key] = [p]
                else:
                    groups[group_key].append(p)
            for (dp_comm, dtype), group in groups.items():
                if dp_comm not in self.comms:
                    continue
                comm = self.comms[dp_comm]
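                # Coalesce the gradients of this group into one flat buffer,
                # all-reduce the buffer, then copy the reduced values back
                # into the individual gradient tensors.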
                grads = [p.grad.data for p in group]
                coalesced = _flatten_dense_tensors(grads)
                if fp32_allreduce and dtype != torch.float32:
                    coalesced = coalesced.float()
                if not no_scale and not reduce_after:
                    coalesced /= comm.size()
                torch.distributed.all_reduce(coalesced, group=comm)
                torch.cuda.synchronize()
                if not no_scale and reduce_after:
                    coalesced /= comm.size()
                synced = _unflatten_dense_tensors(coalesced, grads)
                for g, s in zip(grads, synced):
                    g.copy_(s)

        self.allreduce_params = allreduce_params
        self._sync_params()

    def _sync_params(self):
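        r"""
        Broadcast the parameters from rank 0, one coalesced broadcast per
        (dp_comm, dtype) group, so that all workers in a group start from
        identical weights.
        """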
        groups = dict()
        for p in self.module.parameters():
            # Parameters are broadcast regardless of whether a gradient exists
            # yet; reusing the gradient check from `allreduce_params` here
            # would make the initial sync in `__init__` a no-op.
            if hasattr(p, "dp_comm"):
                dp_comm = p.dp_comm
            else:
                dp_comm = "dp"
            group_key = (dp_comm, p.dtype)
            if group_key not in groups:
                groups[group_key] = [p]
            else:
                groups[group_key].append(p)
        for (dp_comm, _), group in groups.items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
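            # Flatten the parameters of this group, broadcast them from
            # rank 0 in a single call, and copy the synchronized values back.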
            datas = [p.data for p in group]
            coalesced = _flatten_dense_tensors(datas)
            torch.distributed.broadcast(coalesced, 0, group=comm)
            torch.cuda.synchronize()
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r"""
        Directly call the module's forward function.
        """
        return self.module(*args, **kwargs)