import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                          _take_tensors)

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(nn.Module):
    """Distributed data-parallel wrapper that synchronizes model state.

    On construction, the wrapped module's parameters (and optionally its
    buffers) are broadcast from rank 0 so that every process starts from
    identical weights. Unlike ``torch.nn.parallel.DistributedDataParallel``,
    no gradient hooks are registered here; gradient reduction must be
    handled elsewhere by the caller.

    Args:
        module (nn.Module): The module to wrap.
        dim (int): Dimension along which inputs are scattered. Default: 0.
        broadcast_buffers (bool): Whether to also broadcast the module's
            buffers (e.g. BatchNorm running stats) from rank 0.
            Default: True.
    """

    def __init__(self, module, dim=0, broadcast_buffers=True):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        # Coalesce tensors into ~32 MB flat buffers so each broadcast moves
        # a large chunk instead of issuing one call per tensor.
        self.broadcast_bucket_size = 32 * 1024 * 1024
        self._sync_params()

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        """Broadcast ``tensors`` from rank 0 in buckets of ``buffer_size``.

        Each bucket is flattened into a single contiguous tensor, broadcast
        once, then copied back in place into the original tensors.

        Args:
            tensors (list[Tensor]): Tensors to synchronize across ranks.
            buffer_size (int): Maximum bucket size in bytes.
        """
        # NOTE: the loop variable is named ``bucket`` (not ``tensors``) to
        # avoid shadowing the parameter, which the original code did.
        for bucket in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(bucket)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)

    def _sync_params(self):
        """Broadcast parameters (and optionally buffers) from rank 0."""
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            # Use the public ``buffers()`` API instead of the private
            # ``_all_buffers()``, which was removed in newer PyTorch releases.
            buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter positional and keyword args across ``device_ids``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Scatter inputs onto this process's GPU and run the module."""
        # Single-device-per-process model: scatter targets only the current
        # CUDA device, then the wrapped module is invoked directly.
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])