# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
16
17
from abc import ABC
from abc import abstractmethod
18
19
20
# >>>
import math
# <<<
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
21

Raul Puri's avatar
Raul Puri committed
22
23
24
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
25
from megatron import get_args
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
26
from megatron import mpu
27
from .module import MegatronModule
Raul Puri's avatar
Raul Puri committed
28

29

Raul Puri's avatar
Raul Puri committed
30

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
31
32
class MemoryBuffer:
    """Pre-allocated, contiguous 1-D tensor used as backing storage.

    Arguments:
        numel: number of usable elements; `get` never hands out elements
            beyond this count.
        numel_padded: number of elements actually allocated. Expected to be
            >= numel; the trailing padding is kept zeroed and never exposed.
        dtype: element type of the underlying storage.
    """

    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        # Allocate the padded storage once on the current CUDA device. The
        # buffer only ever holds gradient data, so autograd tracking is off.
        self.data = torch.zeros(self.numel_padded,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)


    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()


    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        # A view shares storage with `self.data`; writes through the returned
        # tensor land directly in the buffer.
        return self.data[start_index:end_index].view(shape)


class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP.

    Wraps a model and delegates forward and (de)serialization to it;
    concrete subclasses supply the gradient all-reduce strategy.
    """

    def __init__(self, module):
        super().__init__()
        # Keep a pointer to the wrapped model.
        self.module = module


    @abstractmethod
    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks (subclass hook)."""
        pass


    def forward(self, *inputs, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.module(*inputs, **kwargs)


    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Delegate state-dict creation to the wrapped model."""
        return self.module.state_dict(destination, prefix, keep_vars)


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Delegate checkpoint-oriented state-dict creation to the model."""
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)


    def load_state_dict(self, state_dict, strict=True):
        """Delegate state-dict loading to the wrapped model."""
        self.module.load_state_dict(state_dict, strict=strict)

class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffers options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32)

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
            and the gradient all-reduce all in in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly
        # this means we need main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        self._grad_buffer_param_index_map = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}
            self._grad_buffer_param_index_map = {}
            data_parallel_world_size = mpu.get_data_parallel_world_size()

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return torch.float if \
                    self.accumulate_allreduce_grads_in_fp32 else param.dtype

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():

                # If using distributed optimizer, pad memory buffer to be
                # multiple of data_parallel_world_size. (This padding is done
                # due to a constraint with the reduce_scatter op, which requires
                # all tensors have equal size. See: optimizer.py.)
                num_elements_padded = data_parallel_world_size * \
                    int(math.ceil(num_elements / data_parallel_world_size))

                # Allocate grad buffer.
                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                         num_elements_padded,
                                                         dtype)

            # Assume the back prop order is reverse the params order,
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])
                    # Record the (start, end) element range this param's
                    # gradient occupies within its dtype's buffer, so the
                    # optimizer can locate per-param slices later.
                    if dtype not in self._grad_buffer_param_index_map:
                        self._grad_buffer_param_index_map[dtype] = {}
                    self._grad_buffer_param_index_map[dtype][param] = (
                        type_num_elements[dtype],
                        type_num_elements[dtype] + param.data.nelement(),
                    )

            # Backward hook.
            # Accumulation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)


    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            # NOTE: fixed from `param.grad.data is not None`, which was always
            # true whenever `param.grad` existed and raised AttributeError
            # when it was None.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook


    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()


    def broadcast_params(self):
        """Sync parameters across data parallel ranks from the source rank."""
        for param in self.module.parameters():
            torch.distributed.broadcast(param.data,
                                        src=mpu.get_data_parallel_src_rank(),
                                        group=mpu.get_data_parallel_group())


    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                # Pre-divide so the all-reduce sum yields the average.
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        else:
            # Otherwise, bucketize and all-reduce
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)