r"""
Supportive modules to conduct distributed training
"""

import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7
from .utils import get_torch_default_comm, get_rank_0_in_comm
8
9
10


class DistributedGroupedDataParallel(nn.Module):
    r"""
    A customized DDP module to support different all-reduce regions in the
    model.  The all-reduce region is defined as an attribute `dp_comm` in the
    weight object.
    The grads of the weights are identified to be reduced in different groups
    according to the weights' `dp_comm` attribute.
    If it is set to `dp`, it will only be reduced across the data-parallel
    groups, which means that in the model parallel group, they are not
    synchronized.
    If it is set to `world`, the gradients are synchronized across all
    workers, regardless of their model or data parallel group. This is
    extremely useful for shared layers like the gate.

    Parameters without a `dp_comm` attribute are treated as `dp`.  Parameters
    whose `dp_comm` names a communicator that is not registered in
    `self.comms` are silently left untouched by both synchronization and
    all-reduce.

    Args:
        module: the wrapped module; `forward` is delegated to it.
        auto_allreduce: must be false; automatic all-reduce is not
            implemented.
        need_sync: when true, broadcast the parameters within every
            communicator at construction time.
        **kwargs: every keyword named `<name>_group` registers its value as
            the communicator `<name>`.  The communicators `dp`, `gate`,
            `moe` and `world` fall back to `get_torch_default_comm()` when
            not supplied.  Other keywords are ignored.
    """

    def __init__(
        self,
        module,
        auto_allreduce=False,
        need_sync=True,
        **kwargs
    ):
        assert not auto_allreduce, "Automatic all-reduce is not implemented yet"

        super().__init__()
        self.module = module

        # Map communicator name -> process group, e.g. `dp_group=g` -> 'dp'.
        self.comms = dict()
        for k in kwargs:
            if k.endswith('_group'):
                self.comms[k[:-6]] = kwargs[k]
        for k in ['dp', 'gate', 'moe', 'world']:
            if k not in self.comms:
                self.comms[k] = get_torch_default_comm()

        if need_sync:
            self._sync_params()

    def _group_params(self, with_grad_only):
        r"""
        Bucket the wrapped module's parameters by `(dp_comm, dtype)`.

        Tensors are bucketed per dtype because they are later flattened into
        one contiguous buffer per bucket.  When `with_grad_only` is true,
        parameters that do not require grad or have no grad yet are skipped.
        """
        groups = dict()
        for p in self.module.parameters():
            if with_grad_only and (not p.requires_grad or p.grad is None):
                continue
            group_key = (getattr(p, "dp_comm", "dp"), p.dtype)
            groups.setdefault(group_key, []).append(p)
        return groups

    def allreduce_params(self, no_scale=False,
                         reduce_after=False, fp32_allreduce=False):
        r"""
        All-reduce the gradients of the wrapped module in place, each within
        the communicator selected by the parameter's `dp_comm` attribute.

        This is a regular method (rather than a closure stored on the
        instance) so that copies of the wrapper operate on their own module.

        Args:
            no_scale: when true, the gradients are summed but not divided by
                the communicator size.
            reduce_after: when scaling, divide after the all-reduce instead
                of before it.
            fp32_allreduce: when true, non-fp32 gradients are communicated in
                fp32 and cast back when written to the original gradients.
        """
        for (dp_comm, dtype), group in self._group_params(True).items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            grads = [p.grad.data for p in group]
            # One flat buffer per bucket keeps it to one collective call.
            coalesced = _flatten_dense_tensors(grads)
            if fp32_allreduce and dtype != torch.float32:
                coalesced = coalesced.float()
            if not no_scale and not reduce_after:
                coalesced /= comm.size()
            torch.distributed.all_reduce(coalesced, group=comm)
            if not no_scale and reduce_after:
                coalesced /= comm.size()
            synced = _unflatten_dense_tensors(coalesced, grads)
            for g, s in zip(grads, synced):
                # copy_ also casts back when the buffer was promoted to fp32.
                g.copy_(s)

    def _sync_params(self):
        r"""
        Broadcast the parameters of the wrapped module within each registered
        communicator, overwriting the local values in place.
        """
        for (dp_comm, _), group in self._group_params(False).items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            datas = [p.data for p in group]
            coalesced = _flatten_dense_tensors(datas)
            # NOTE(review): get_rank_0_in_comm presumably returns the global
            # rank of the first member of `comm` - confirm in `.utils`.
            torch.distributed.broadcast(coalesced,
                                        get_rank_0_in_comm(comm), group=comm)
            # Only synchronize when CUDA exists, so that CPU-only hosts do
            # not fail here.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r"""
        Directly call the module's forward function.
        """
        return self.module(*args, **kwargs)