"git@developer.sourcefind.cn:change/sglang.git" did not exist on "4aa1e69bc7b63ca4147e0154b3171010b09643bf"
parallel.py 7.06 KB
Newer Older
Hang Zhang's avatar
docs  
Hang Zhang committed
1
2
3
4
5
6
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
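
# A minimal sketch of calling ``allreduce`` (the tensor names are
# hypothetical per-GPU statistics, two per GPU, so the leading
# ``num_inputs`` argument consumed by ``AllReduce.forward`` is 2; the
# result is flattened in the same per-GPU order as the inputs):
#
#   sum0, sqr0, sum1, sqr1 = allreduce(2, sum_gpu0, sqr_gpu0,
#                                      sum_gpu1, sqr_gpu1)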

class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)
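
# ``Reduce`` sums the per-GPU tensors into a single tensor with
# ``comm.reduce_add`` and broadcasts the gradient back to every GPU on the
# backward pass. ``DataParallelCriterion`` below uses it to combine the
# per-GPU losses, e.g. (hypothetical tensors):
#
#   total_loss = Reduce.apply(loss_gpu0, loss_gpu1) / 2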


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backward
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculate loss on multiple GPUs, which balances the memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use it together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # input should be already scattered
        # scattering the targets instead
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
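        # sum the per-GPU losses with Reduce, then average over the replicas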
        return Reduce.apply(*outputs) / len(outputs)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # pass targets[0] as well so arguments line up with _worker's signature
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
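
# A minimal end-to-end sketch under assumed names (``model``, ``criterion``,
# and ``loader`` are hypothetical and not part of this module):
#
#   model = DataParallelModel(model, device_ids=[0, 1]).cuda()
#   criterion = DataParallelCriterion(criterion, device_ids=[0, 1])
#   for x, y in loader:
#       outputs = model(x.cuda())            # per-GPU outputs, not gathered
#       loss = criterion(outputs, y.cuda())  # scalar averaged across GPUs
#       loss.backward()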