r'''
Supporting modules for conducting distributed training
'''
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .utils import get_torch_default_comm


class DistributedGroupedDataParallel(nn.Module):
    r'''
    A customized DDP module that supports different all-reduce regions in the
    model.  An all-reduce region is defined by the attribute `dp_comm` on the
    weight object.
    The gradients of the weights are reduced in different groups according to
    the weights' `dp_comm` attribute.
    If it is set to `dp`, the gradient is only reduced across the data-parallel
    group, which means that it is not synchronized within the model-parallel
    group.
    If it is set to `world`, the gradient is synchronized across all workers,
    regardless of their model- or data-parallel group. This is extremely useful
    for shared layers such as the gate.
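
    Example (a minimal sketch; `model`, `optimizer`, `inputs` and the process
    groups are placeholders, assumed to be created via
    `torch.distributed.new_group` after `torch.distributed` is initialized):

        # Synchronize the shared gate across all workers.
        for p in model.gate.parameters():
            p.dp_comm = 'world'
        model = DistributedGroupedDataParallel(
            model, mp_group=mp_group, dp_group=dp_group)

        # Gradients are not reduced automatically during backward; call
        # `allreduce_params` explicitly before the optimizer step.
        loss = model(inputs).sum()
        loss.backward()
        model.allreduce_params()
        optimizer.step()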
    '''
    def __init__(self, module, mp_group=None, dp_group=None, world_group=None,
            auto_allreduce=False):
        assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'

        super().__init__()
        self.module = module

        self.comms = dict()
        if mp_group is not None:
            self.comms['mp'] = mp_group
        if dp_group is not None:
            self.comms['dp'] = dp_group
        else:
            self.comms['dp'] = get_torch_default_comm()
        if world_group is None:
            self.comms['world'] = get_torch_default_comm()
        else:
            self.comms['world'] = world_group

        def allreduce_params(no_scale=False, reduce_after=False,
                fp32_allreduce=False):
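            # Bucket the gradients by (dp_comm, dtype) so that each group can
            # be flattened and reduced with a single collective per bucket.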
            groups = dict()
            for p in self.module.parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                if hasattr(p, 'dp_comm'):
                    dp_comm = p.dp_comm
                else:
                    dp_comm = 'dp'
                group_key = (dp_comm, p.dtype)
                if group_key not in groups:
                    groups[group_key] = [p]
                else:
                    groups[group_key].append(p)
            for (dp_comm, dtype), group in groups.items():
                if dp_comm not in self.comms:
                    continue
                comm = self.comms[dp_comm]
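                # Flatten the bucket into one contiguous buffer, all-reduce it
                # over `comm`, optionally scaling by the group size before or
                # after the reduction, then copy the results back.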
                grads = [p.grad.data for p in group]
                coalesced = _flatten_dense_tensors(grads)
                if fp32_allreduce and dtype != torch.float32:
                    coalesced = coalesced.float()
                if not no_scale and not reduce_after:
                    coalesced /= comm.size()
                torch.distributed.all_reduce(coalesced, group=comm)
                torch.cuda.synchronize()
                if not no_scale and reduce_after:
                    coalesced /= comm.size()
                synced = _unflatten_dense_tensors(coalesced, grads)
                for g, s in zip(grads, synced):
                    g.copy_(s)

        self.allreduce_params = allreduce_params
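        # Broadcast the initial parameters once so that all replicas start
        # from the same weights.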
        self._sync_params()

    def _sync_params(self):
        groups = dict()
        for p in self.module.parameters():
            if not p.requires_grad or p.grad is None:
                continue
            if hasattr(p, 'dp_comm'):
                dp_comm = p.dp_comm
            else:
                dp_comm = 'dp'
            group_key = (dp_comm, p.dtype)
            if group_key not in groups:
                groups[group_key] = [p]
            else:
                groups[group_key].append(p)
        for (dp_comm, dtype), group in groups.items():
            if dp_comm not in self.comms:
                continue
            comm = self.comms[dp_comm]
            datas = [p.data for p in group]
            coalesced = _flatten_dense_tensors(datas)
            # Parameters are broadcast from rank 0 so that every worker in the
            # group starts from identical weights.
            torch.distributed.broadcast(coalesced, 0, group=comm)
            torch.cuda.synchronize()
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r'''
        Directly call the module's forward function.
        '''
        return self.module(*args, **kwargs)