##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import torch
from torch.autograd import Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']


def allreduce(num_inputs, *inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
    result = ReduceAddCoalesced.apply(target_gpus[0], num_inputs, *inputs)
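    # broadcast the reduced tensors back so every GPU sees identical statistics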
    outputs = Broadcast.apply(target_gpus, *result)
    assert len(outputs) == len(inputs)
    return outputs


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
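        # remember which device each input lives on so backward can broadcast
        # the gradient back to every contributing GPU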
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
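        # replicate the incoming gradient to every GPU that contributed an input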
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
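        # identity gather: leave the outputs on their own devices so that
        # DataParallelCriterion can compute the loss per GPU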
        return outputs


class DataParallelCriterion(DataParallel):
    """
    Calculate the loss on multiple GPUs, which balances memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use it together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # the inputs should already be scattered (one chunk per device);
        # only the targets need to be scattered here
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
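        # replicate the criterion onto the devices holding the output chunks,
        # compute the per-GPU losses in parallel, then average them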
        replicas = replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        return Reduce.apply(*outputs) / len(outputs)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
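    # threaded variant of torch.nn.parallel.parallel_apply that also feeds a
    # per-GPU target chunk to each criterion replica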
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
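    # `results` is shared across the worker threads; every write is guarded by `lock`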
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            # run the replica on the device that holds its input chunk
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs