##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import torch
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['allreduce', 'ModelDataParallel', 'CriterionDataParallel']


def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
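
    Example (a minimal sketch; ``per_gpu_stats`` is a hypothetical list
    holding one statistics tensor per GPU)::

        >>> reduced = allreduce(*per_gpu_stats)
        >>> # each element of ``reduced`` holds the sum of all the inputs,
        >>> # broadcast back to the corresponding device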
    """
    # reduce-add all inputs onto the first input's device, then broadcast
    # the summed result back to every participating device
    target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
    result = ReduceAddCoalesced.apply(target_gpus[0], 1, *inputs)
    outputs = Broadcast.apply(target_gpus, *result)
    assert len(outputs) == len(inputs)
    return outputs


class ModelDataParallel(Module):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.CriterionDataParallel`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size and each GPU processes the same number of samples.
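    For example, with ``device_ids=[0, 1, 2]`` a batch of 12 samples is
    split into three chunks of 4 samples each.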

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
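        >>> # ``y`` is a list of per-GPU outputs (not gathered); pair it
        >>> # with CriterionDataParallel to compute the loss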
    """
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(ModelDataParallel, self).__init__()
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        self.master_mean, self.master_var = {}, {}
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        # scatter the inputs across the devices, chunking along dim
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # replicate the module onto one device per input chunk and run the
        # replicas in parallel; the outputs stay on their own devices
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs)


class CriterionDataParallel(Module):
    """
    Calculate loss on multiple GPUs, which balances the memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use it together with
    :class:`encoding.parallel.ModelDataParallel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.CriterionDataParallel(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
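        >>> # typical next step (a sketch): ``loss`` is averaged over the
        >>> # replicas and lives on ``output_device``
        >>> loss.backward()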
    """
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(CriterionDataParallel, self).__init__()
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, inputs, *targets, **kwargs):
        # the inputs (per-GPU network outputs) are already scattered,
        # so only the targets need to be scattered here
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
        # ReduceAddCoalesced returns a tuple of reduced tensors; take the
        # single summed loss and average it over the replicas
        reduced = ReduceAddCoalesced.apply(self.output_device, 1, *outputs)
        return reduced[0] / len(outputs)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, targets, kwargs):
        return _criterion_parallel_apply(replicas, inputs, targets, kwargs)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
    """Apply each criterion replica to its (input, target) pair on its own
    device, using one Python thread per replica, and return the losses."""
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    # Fast track: a single replica needs no threads; call it the same way
    # the worker does (input passed whole, target unpacked)
    if len(modules) == 1:
        return (modules[0](inputs[0], *targets[0], **kwargs_tup[0]), )

    lock = threading.Lock()
    results = {}

    def _worker(i, module, input, target, kwargs, results, lock):
        try:
            # find a Variable in the input to pick the CUDA device to run on
            var_input = _get_a_var(input)
            with torch.cuda.device_of(var_input):
                output = module(input, *target, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # record the exception so it can be re-raised in the caller
            with lock:
                results[i] = e

    threads = [threading.Thread(target=_worker,
                                args=(i, module, input, target,
                                      kwargs, results, lock),)
               for i, (module, input, target, kwargs) in
               enumerate(zip(modules, inputs, targets, kwargs_tup))]

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        # re-raise any exception captured in a worker thread
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


def _get_a_var(obj):
    """Return the first Variable found in ``obj``, searching lists, tuples
    and dicts recursively; return None if no Variable is found."""
    if isinstance(obj, Variable):
        return obj

    if isinstance(obj, (list, tuple)):
        results = map(_get_a_var, obj)
        for result in results:
            if isinstance(result, Variable):
                return result
    if isinstance(obj, dict):
        results = map(_get_a_var, obj.items())
        for result in results:
            if isinstance(result, Variable):
                return result
    return None