r'''
Supporting modules for conducting distributed training.
'''

import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7
from .utils import get_torch_default_comm
8
9
10


class DistributedGroupedDataParallel(nn.Module):
    r'''
    A customized DDP module to support different all-reduce regions in the
    model.  The all-reduce region of a weight is defined by the `dp_comm`
    attribute of the weight object; weights without that attribute are
    treated as `dp`.
    The grads of the weights are reduced in different groups according to the
    weights' `dp_comm` attribute.
    If it is set to `dp`, it will only be reduced across the data-parallel
    groups, which means that in the model parallel group, they are not
    synchronized.
    If it is set to `world`, the gradients are synchronized across all
    workers, regardless of their model or data parallel group. This is
    extremely useful for shared layers like the gate.
    Weights whose `dp_comm` does not name a registered communicator are
    neither broadcast nor reduced by this wrapper.

    Args:
        module: the module to wrap; `forward` is delegated to it.
        mp_group: optional communicator registered under the key `mp`.
        dp_group: communicator registered under the key `dp`; the result of
            `get_torch_default_comm()` is used when it is `None`.
        world_group: communicator registered under the key `world`; the
            result of `get_torch_default_comm()` is used when it is `None`.
        auto_allreduce: must be false, automatic all-reduce is not
            implemented.
    '''
    def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
            auto_allreduce=False):
        assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'

        super().__init__()
        self.module = module

        self.comms = dict()
        if mp_group is not None:
            self.comms['mp'] = mp_group
        if dp_group is not None:
            self.comms['dp'] = dp_group
        else:
            self.comms['dp'] = get_torch_default_comm()
        if world_group is None:
            self.comms['world'] = get_torch_default_comm()
        else:
            self.comms['world'] = world_group

        self._sync_params()

    def _group_params(self, need_grad):
        r'''
        Bucket the trainable parameters of the wrapped module by
        `(dp_comm, dtype)`, so that each bucket can be flattened into a
        single tensor.  When `need_grad` is true, parameters that currently
        have no gradient are left out.
        '''
        groups = dict()
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            if need_grad and p.grad is None:
                continue
            dp_comm = getattr(p, 'dp_comm', 'dp')
            groups.setdefault((dp_comm, p.dtype), []).append(p)
        return groups

    @staticmethod
    def _first_rank_in_comm(comm, device):
        r'''
        Return the global rank of the process that is rank 0 inside `comm`.
        `torch.distributed.broadcast` expects its source as a global rank, so
        the members of `comm` gather their global ranks and the first entry
        is used.  This is a collective call over `comm`.
        '''
        mine = torch.tensor([torch.distributed.get_rank()],
                dtype=torch.int64, device=device)
        gathered = [torch.empty_like(mine) for _ in range(comm.size())]
        torch.distributed.all_gather(gathered, mine, group=comm)
        return int(gathered[0].item())

    def allreduce_params(self, no_scale=False, reduce_after=False,
            fp32_allreduce=False):
        r'''
        All-reduce the gradients of the wrapped module, one flattened tensor
        per `(dp_comm, dtype)` bucket, and write the result back in place.

        Args:
            no_scale: when true, the summed gradients are not divided by the
                size of the communicator.
            reduce_after: when scaling, divide after the all-reduce instead of
                before it.
            fp32_allreduce: when true, buckets that are not `float32` are
                converted to `float32` for the communication.
        '''
        for (dp_comm, dtype), group in self._group_params(True).items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            grads = [p.grad.data for p in group]
            coalesced = _flatten_dense_tensors(grads)
            if fp32_allreduce and dtype != torch.float32:
                coalesced = coalesced.float()
            if not no_scale and not reduce_after:
                coalesced /= comm.size()
            torch.distributed.all_reduce(coalesced, group=comm)
            torch.cuda.synchronize()
            if not no_scale and reduce_after:
                coalesced /= comm.size()
            synced = _unflatten_dense_tensors(coalesced, grads)
            # copy_ also converts back to the original dtype when the
            # communication was done in float32.
            for g, s in zip(grads, synced):
                g.copy_(s)

    def _sync_params(self):
        r'''
        Broadcast the trainable parameters from the first member of each
        communicator, so that all members of a communicator start from the
        same weights.  Parameters are selected regardless of whether they
        have a gradient, because this runs at construction time, before any
        backward pass.
        '''
        roots = dict()
        for (dp_comm, _), group in self._group_params(False).items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            datas = [p.data for p in group]
            coalesced = _flatten_dense_tensors(datas)
            # Several dtype buckets may share one communicator; resolve its
            # source rank only once.
            if dp_comm not in roots:
                roots[dp_comm] = self._first_rank_in_comm(comm,
                        coalesced.device)
            torch.distributed.broadcast(coalesced, roots[dp_comm], group=comm)
            torch.cuda.synchronize()
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r'''
        Directly call the module's forward function.
        '''
        return self.module(*args, **kwargs)