##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
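
    Example (an illustrative sketch; the tensor names below are placeholders)::

        >>> # two statistics per device (e.g. sum and squared sum) on two GPUs
        >>> outs = allreduce(2, sum_0, sqsum_0, sum_1, sqsum_1)
        >>> # ``outs`` holds the cross-GPU totals, broadcast back to both devices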
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
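    """Coalesced cross-GPU all-reduce.

    The flattened ``inputs`` hold ``num_inputs`` tensors per device; each group
    is summed across devices and the sums are broadcast back to every device.
    """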
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                 for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                 for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
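    """Sum one tensor from each GPU into a single tensor.

    The backward pass broadcasts the incoming gradient back to every
    contributing device.
    """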
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
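        # Keep the per-device outputs as-is (no gathering), so that a
        # DataParallelCriterion can consume them in place.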
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
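        # run the __data_parallel_replicate__ callbacks on every replica
        # (used by synchronized layers such as SyncBN)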
        execute_replication_callbacks(modules)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculate loss across multiple GPUs, which balances the memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Use it together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # the inputs are already scattered across devices;
        # only the targets need to be scattered here
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
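        # sum the per-device losses and average them; Reduce's backward
        # broadcasts the gradient back to every replica's device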
        return Reduce.apply(*outputs) / len(outputs)
        #return self.gather(outputs, self.output_device).mean()


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
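    """Apply each criterion replica to its (input, target) shard, using one
    worker thread per device when more than one replica is given.
    """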
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
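                # both ``input`` and ``target`` are expected to be per-device
                # tuples/lists, so they can be concatenated and unpacked below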
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs

###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch

class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute the replication callback `__data_parallel_replicate__` on each module
    created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all replicas are isomorphic, we assign each sub-module a context
    (shared among the copies of that sub-module on different devices).
    Through this context, different copies can share information.

    We guarantee that the callback on the master copy (the first copy) is called
    before the callbacks of any slave copies.
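
    Example (a hypothetical module opting into this protocol; illustration only)::

        >>> class MyModule(nn.Module):
        ...     def __data_parallel_replicate__(self, ctx, copy_id):
        ...         # `ctx` is shared by all replicas of this sub-module;
        ...         # replica 0 is the master copy.
        ...         self._sync_ctx = ctx
        ...         self._is_master = (copy_id == 0)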
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate