"docs/EN/vscode:/vscode.git/clone" did not exist on "5136abf35abf106bce27e4692512e269e071b451"
distributed.py 9.2 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Raul Puri's avatar
Raul Puri committed
2

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
3
4
from abc import ABC
from abc import abstractmethod
5
import math
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
6

Raul Puri's avatar
Raul Puri committed
7
8
9
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
10
from megatron import get_args
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
11
from megatron import mpu
12
from .module import MegatronModule
Raul Puri's avatar
Raul Puri committed
13

14

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
15
16
class MemoryBuffer:
    """A flat, zero-initialized 1-D tensor from which shaped views are carved.

    Arguments:
        numel: number of usable elements; `get` refuses views that extend
            past this index.
        numel_padded: number of elements actually allocated (must be
            >= `numel`); the tail beyond `numel` is padding.
        dtype: element dtype of the buffer.
        device: device to allocate on. Defaults to
            `torch.cuda.current_device()` (the original behavior).
    """

    def __init__(self, numel, numel_padded, dtype, device=None):
        assert numel_padded >= numel, \
            'padded buffer size must be at least the unpadded size.'
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        if device is None:
            device = torch.cuda.current_device()
        self.data = torch.zeros(self.numel_padded,
                                dtype=self.dtype,
                                device=device,
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero (in place, so existing views see it)."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`.

        `shape` may be a `torch.Size` or any sequence of ints. The returned
        tensor shares storage with `self.data`, so writes to it land in the
        buffer.

        Raises:
            AssertionError: if `start_index` is negative or the requested
                range extends past `self.numel`.
        """
        shape = torch.Size(shape)
        # A negative start would make the slice below wrap around or come
        # back empty, producing a confusing error in `view`.
        assert start_index >= 0, \
            'requested tensor has a negative start index.'
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        return self.data[start_index:end_index].view(shape)


Raul Puri's avatar
Raul Puri committed
42

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
43
44
45
46
47
48
class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP.

    Wraps `module` and forwards calls and state-dict operations to it.
    Subclasses must implement `allreduce_gradients`.
    """

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        pass

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module unchanged."""
        return self.module(*inputs, **kwargs)

    def state_dict(self, prefix='', keep_vars=False):
        """Return the wrapped module's state dict.

        Delegating (rather than using this wrapper's own `state_dict`) keeps
        the keys free of an extra `module.` prefix from this wrapper.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Delegate checkpoint-state collection to the wrapped module."""
        return self.module.state_dict_for_save_checkpoint(prefix=prefix,
                                                          keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load `state_dict` into the wrapped module.

        Returns whatever the wrapped module's `load_state_dict` returns (for a
        `torch.nn.Module`, the named tuple of missing/unexpected keys), so
        callers of the wrapper can inspect it.
        """
        return self.module.load_state_dict(state_dict, strict=strict)

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffers options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32)

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
            and the gradient all-reduce all in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly
        # this means we need main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers, \
                'accumulate_allreduce_grads_in_fp32 requires ' \
                'use_contiguous_buffers.'

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        self._grad_buffer_param_index_map = None
        if not self.use_contiguous_buffers:
            return

        self._grad_buffers = {}
        self._grad_buffer_param_index_map = {}
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Simple function to define buffer type.
        def _get_buffer_type(param):
            return torch.float if \
                self.accumulate_allreduce_grads_in_fp32 else param.dtype

        # Only trainable params get a slot in a grad buffer; collect them once
        # so the three passes below see the same params in the same order.
        grad_params = [param for param in self.module.parameters()
                       if param.requires_grad]

        # First calculate total number of elements per type.
        type_num_elements = {}
        for param in grad_params:
            dtype = _get_buffer_type(param)
            type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                + param.data.nelement()

        # Allocate one buffer per type.
        for dtype, num_elements in type_num_elements.items():
            # If using distributed optimizer, pad memory buffer to be
            # multiple of data_parallel_world_size. (This padding is done
            # due to a constraint with the reduce_scatter op, which requires
            # all tensors have equal size. See: optimizer.py.)
            # Integer ceiling division keeps this exact for any size.
            num_shards = (num_elements + data_parallel_world_size - 1) \
                // data_parallel_world_size
            num_elements_padded = data_parallel_world_size * num_shards

            self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                     num_elements_padded,
                                                     dtype)

        # Assume the back prop order is the reverse of the params order:
        # slots are handed out from the end of each buffer toward its start,
        # and each param's `main_grad` becomes a view of its slot.
        for param in grad_params:
            dtype = _get_buffer_type(param)
            param_numel = param.data.nelement()
            type_num_elements[dtype] -= param_numel
            start_index = type_num_elements[dtype]
            param.main_grad = self._grad_buffers[dtype].get(
                param.data.shape, start_index)
            # Record the [start, end) range of this param within its buffer.
            self._grad_buffer_param_index_map.setdefault(dtype, {})[param] = (
                start_index,
                start_index + param_numel,
            )

        # Backward hook.
        # Accumulation function for the gradients. We need
        # to store them so they don't go out of scope.
        self.grad_accs = []
        for param in grad_params:
            # Expand so we get access to grad_fn.
            param_tmp = param.expand_as(param)
            # Get the gradient accumulator function.
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            grad_acc.register_hook(self._make_param_hook(param))
            self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the backprop hook for `param`.

        The hook runs after `param`'s gradient-accumulator node, adds
        `param.grad` into `param.main_grad` (the buffer view), then drops
        `param.grad` so its memory can be freed.
        """
        def param_hook(*unused):
            # Add the gradient to the buffer.
            # NOTE(review): param.grad may be None here, presumably when a
            # fused linear/GEMM kernel already wrote into main_grad — confirm.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for buffer_ in self._grad_buffers.values():
            buffer_.zero()

    def broadcast_params(self):
        """Broadcast every parameter from the data-parallel source rank."""
        src_rank = mpu.get_data_parallel_src_rank()
        group = mpu.get_data_parallel_group()
        for param in self.module.parameters():
            torch.distributed.broadcast(param.data, src=src_rank, group=group)

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks.

        Gradients are divided by the data-parallel world size and then
        all-reduced (default op: sum), i.e. averaged over ranks.
        """
        world_size = mpu.get_data_parallel_world_size()
        group = mpu.get_data_parallel_group()

        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for buffer_ in self._grad_buffers.values():
                buffer_.data /= world_size
                torch.distributed.all_reduce(buffer_.data, group=group)
            return

        # Otherwise, bucketize by tensor type and all-reduce each bucket.
        buckets = {}
        for param in self.module.parameters():
            if param.requires_grad and param.grad is not None:
                buckets.setdefault(param.data.type(), []).append(param)
                param.main_grad = param.grad

        # For each bucket, all-reduce and copy all-reduced grads.
        for bucket in buckets.values():
            grads = [param.grad.data for param in bucket]
            coalesced = _flatten_dense_tensors(grads)
            coalesced /= world_size
            torch.distributed.all_reduce(coalesced, group=group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(
                    coalesced, grads)):
                buf.copy_(synced)