"src/vscode:/vscode.git/clone" did not exist on "fa343873a87dadbc64f08f3b247f8be7fc9f94ff"
parallel.py 6.06 KB
Newer Older
Hang Zhang's avatar
docs  
Hang Zhang committed
1
2
3
4
5
6
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import torch
from torch.autograd import Function
import torch.cuda.comm as comm
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']


def allreduce(num_inputs, *inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    # inputs are grouped per device: `num_inputs` tensors from the first GPU,
    # then `num_inputs` tensors from the next GPU, and so on
    target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
    result = ReduceAddCoalesced.apply(target_gpus[0], num_inputs, *inputs)
    # broadcast the reduced tensors back so every GPU holds identical copies
    outputs = Broadcast.apply(target_gpus, *result)
    assert len(outputs) == len(inputs)
    return outputs
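
# A minimal usage sketch of `allreduce` for synchronized statistics: each GPU
# contributes its local sum and squared sum, the reduced values are broadcast
# back, and every device ends up with identical copies. The tensor names,
# sizes, and the two-GPU layout below are purely illustrative and assume at
# least two visible CUDA devices.
def _allreduce_sketch():
    xsum0, xsqsum0 = torch.randn(64, device='cuda:0'), torch.randn(64, device='cuda:0')
    xsum1, xsqsum1 = torch.randn(64, device='cuda:1'), torch.randn(64, device='cuda:1')
    # two tensors per GPU -> num_inputs=2; inputs are grouped per device
    outputs = allreduce(2, xsum0, xsqsum0, xsum1, xsqsum1)
    # outputs keep the same per-device layout: (sum, sqsum) on cuda:0, then on
    # cuda:1, and every tensor now holds the cross-GPU reduced value
    assert len(outputs) == 4
    return outputs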


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        # remember which GPU each input lives on, then sum them all onto the
        # device of the first input
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        # send the incoming gradient back to every contributing GPU
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs
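
# Because `gather` above is a no-op, calling a `DataParallelModel` returns a
# Python list with one entry per GPU, each output still resident on its own
# device. A hedged sketch (the `model` argument and the input `x` are
# hypothetical placeholders; GPUs 0 and 1 are assumed to be available):
def _data_parallel_model_sketch(model, x):
    net = DataParallelModel(model, device_ids=[0, 1])
    # outputs[0] stays on cuda:0 and outputs[1] on cuda:1; they are not
    # gathered, so they can be fed directly to a DataParallelCriterion
    outputs = net(x)
    return outputs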


class DataParallelCriterion(DataParallel):
    """
    Calculate the loss on multiple GPUs, which balances the memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use it together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # inputs should already be scattered (one chunk per GPU);
        # only the targets are scattered here
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = scatter_kwargs(targets, kwargs, self.device_ids, dim=self.dim)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        return Reduce.apply(*outputs) / len(outputs)
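
# A minimal end-to-end sketch of how the two containers are meant to be
# combined in a training step. The `model`, `criterion`, `train_loader`, and
# `optimizer` arguments are hypothetical placeholders, and GPUs 0-2 are
# assumed to be available.
def _training_step_sketch(model, criterion, train_loader, optimizer):
    model = DataParallelModel(model, device_ids=[0, 1, 2])
    criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
    for data, target in train_loader:
        optimizer.zero_grad()
        # list of per-GPU outputs; intentionally not gathered
        outputs = model(data.cuda())
        # the criterion scatters `target`, computes the loss on each GPU, and
        # reduces the per-GPU losses to a single scalar
        loss = criterion(outputs, target.cuda())
        loss.backward()
        optimizer.step()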


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # make both tuples so that `input + target` concatenates the
                # argument lists instead of adding tensors element-wise
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs