import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional, Tuple

import torch
from torch.distributed import GradBucket

import nanotron.distributed as dist
from nanotron import logging
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage

logger = logging.get_logger(__name__)


class GradientAccumulator(ABC):
    fp32_grads_allreduce_handle: Optional[torch.futures.Future]

    @abstractmethod
    def __init__(self, named_parameters: Iterator[Tuple[str, NanotronParameter]]):
        ...

    @abstractmethod
    def backward(self, loss: torch.Tensor):
        ...

    @abstractmethod
    def step(self):
        ...

    @abstractmethod
    def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.ReduceOp, reduce_scatter: bool):
        ...

    @abstractmethod
    def zero_grad(self):
        ...

    @abstractmethod
    def get_parameter_for_optimizer(self, name: str) -> NanotronParameter:
        ...

    @abstractmethod
    def get_grad_buffer(self, name: str) -> torch.Tensor:
        ...

    @abstractmethod
    def state_dict(self) -> Dict[str, torch.Tensor]:
        ...

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        ...


class FP32GradientAccumulator(GradientAccumulator):
    def __init__(
        self,
        named_parameters: Iterator[Tuple[str, NanotronParameter]],
        grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None,
    ):
        """Create a gradient accumulator that will accumulate gradients in fp32.

        Args:
            named_parameters: The parameters that will be updated by the optimizer. In case of Zero 1, this is the parameters that will be updated in this DP rank.
            grad_buckets_named_params: The parameters to accumulate gradients for. If None it defaults to `named_parameters`. In case of Zero 1, this should be all the parameters in the model.

        Note: We use `grad_buckets_named_params` to keep grad buffers for all parameters even when Zero 1 is used. This is because we need to accumulate gradients for all parameters without having to reduce in every accumulation step.
        Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator
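
        Example:
            A minimal usage sketch; `model`, `dataloader`, and the AdamW choice are illustrative assumptions, not part of this module:

                accumulator = FP32GradientAccumulator(model.named_parameters())
                optimizer = torch.optim.AdamW(
                    [accumulator.get_parameter_for_optimizer(name) for name, p in model.named_parameters() if p.requires_grad],
                    lr=1e-4,
                )
                for batch in dataloader:
                    loss = model(batch)
                    accumulator.backward(loss)  # accumulate half-precision grads into the fp32 grad buffers
                    optimizer.step()            # update the fp32 master weights
                    accumulator.step()          # copy the fp32 weights back into the half-precision params
                    accumulator.zero_grad()
                    optimizer.zero_grad()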
        """
        if grad_buckets_named_params is None:
            named_parameters = list(named_parameters)
            grad_buckets_named_params = named_parameters

        # Initialize grad bucket
        self.fp32_grad_buffers, self._contiguous_fp32_grad_buffer = self.build_grad_buffers(
            named_parameters=grad_buckets_named_params
        )

        # Assign big buffer for weights + grad in fp32
        segment_index = {}
        length = 0
        for name, param in named_parameters:
            if not param.requires_grad:
                continue

            start = length
            end_weight = start + param.numel()
            assert name not in segment_index
            segment_index[name] = (start, end_weight, param)
            length = end_weight

        big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda")
        self.parameters = {
            name: {
                "fp32": big_flat_buffer[start_weight:end_weight].view_as(param),
                "half": param,
            }
            for name, (start_weight, end_weight, param) in segment_index.items()
        }

        with torch.inference_mode():
            for _, elt in self.parameters.items():
                fp32_param = elt["fp32"]
                half_param = elt["half"]

                # Check that fp32 weights have the same memory representation as half precision weights
                assert fp32_param.stride() == half_param.stride()

                # Copy weights from half precision to full precision
                fp32_param.copy_(half_param)

                # Set requires_grad=True
                fp32_param.requires_grad = True

        self._is_accumulation_sync_step = False
        # We need the last allreduce handle to make sure it finishes before the optimizer step
        self.fp32_grads_allreduce_handle: Optional[torch.futures.Future] = None

    def assign_param_offsets(self, param_name_to_offsets: Dict[str, Dict[int, Tuple[int, int]]], dp_rank: int):
        """To use only when you use with ZeRODistributedOptimizer"""
        self.param_name_to_offsets = {
            name: elt[dp_rank] for name, elt in param_name_to_offsets.items() if dp_rank in elt
        }

    def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.ReduceOp, reduce_scatter: bool):
        if dp_pg.size() == 1:
            # They are already synced
            return

        if reduce_scatter:
            # Usually you need to run `all_reduce` in order for all gradients to be synced.
            # However, when the optimizer states are sharded (ZeRO-1), you only need to reduce each gradient shard onto the rank that runs the optimizer step for it.
            # Effectively you replace an `all_reduce` with a `reduce_scatter`, which should save an `all_gather` when using the ring algorithm.
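            # (A ring `all_reduce` is a `reduce_scatter` followed by an `all_gather`; since each rank only
            # needs its own fp32 grad shard to run its optimizer step, the trailing `all_gather` can be dropped.)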
            assert hasattr(self, "param_name_to_offsets")
            named_offsets = sorted(self.param_name_to_offsets.items(), key=lambda x: x[0])
            flat_grad_buffers = [self.fp32_grad_buffers[name]["fp32_grad"].view(-1) for name, _ in named_offsets]
            dist.reduce_scatter_coalesced(
                output_tensor_list=[
                    flat_grad_buffer[start_offset:end_offset]
                    for (_, (start_offset, end_offset)), flat_grad_buffer in zip(named_offsets, flat_grad_buffers)
                ],
                input_tensor_lists=[
                    torch.split(
                        flat_grad_buffer,
                        split_size_or_sections=len(self.fp32_grad_buffers[name]["fp32_grad"].view(-1)) // dp_pg.size(),
                    )
                    for (name, _), flat_grad_buffer in zip(named_offsets, flat_grad_buffers)
                ],
                group=dp_pg,
            )
        else:
            dist.all_reduce(self._contiguous_fp32_grad_buffer, op=reduce_op, group=dp_pg)

    @staticmethod
    def build_grad_buffers(
        named_parameters: Iterator[Tuple[str, NanotronParameter]],
    ) -> Tuple[Dict[str, Dict], torch.Tensor]:
        """Builds grad buffers for all model's parameters, independently of ZeRO sharding

        Args:
            named_parameters: Parameters to build buckets for. In case of Zero1, this should be all parameters.

        Note:
            In ZeRO-1, we need to accumulate grads for all parameters, because we need to allreduce all parameters' grads across DP at each sync step.
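
        Example:
            A sketch of the returned structure; `model` and the parameter name are illustrative:

                fp32_grad_buffers, contiguous_buffer = FP32GradientAccumulator.build_grad_buffers(
                    named_parameters=model.named_parameters()
                )
                fp32_grad_buffers["decoder.weight"]["half"]       # the original half-precision NanotronParameter
                fp32_grad_buffers["decoder.weight"]["fp32_grad"]  # an fp32 view into `contiguous_buffer`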
        """
        named_parameters = [(name, param) for name, param in named_parameters if param.requires_grad]

        needed_buffer_size = sum(param.numel() for _, param in named_parameters)
        # important to have grads zeroed initially (see `self._accumulate_grad`)
        contiguous_buffer_f32_gradients = torch.zeros(needed_buffer_size, dtype=torch.float, device="cuda")
        untyped_storage = get_untyped_storage(contiguous_buffer_f32_gradients)
        element_size = contiguous_buffer_f32_gradients.element_size()

        # NOTE: Although `bias` can only exist on TP rank 0, that's not a problem here, because we only sync across DP
        fp32_grad_buffers = OrderedDict()  # keeps order of insertion
        offset = 0
        for name, param in named_parameters:
            if not param.requires_grad:
                continue

            assert param.dtype != torch.float, f"Expected {name} not to be float"
            assert param.is_contiguous(), f"Expected {name} to be contiguous"

            next_offset = offset + param.numel() * element_size

            fp32_grad_buffer = tensor_from_untyped_storage(
                untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float
            )

            fp32_grad_buffers[name] = {
                "half": param,
                # We create sliced tensors by also slicing the storage.
                # We need to specify "cuda" so the tensor shares the same data storage; otherwise it would be built on "cpu" and the data copied over.
                "fp32_grad": fp32_grad_buffer.view_as(param),
            }

            offset = next_offset

        return fp32_grad_buffers, contiguous_buffer_f32_gradients

    def backward(self, loss: torch.Tensor):
        result = loss.backward()

        for name, elt in self.fp32_grad_buffers.items():
            self._accumulate_grad(name=name, half_param=elt["half"])

        return result

    def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
        """Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards"""
        assert half_param.grad is not None, f"Expected param {name} to have gradient."
        fp32_grad = self.get_grad_buffer(name=name)

        if self._is_accumulation_sync_step is False:
            # WARNING: We assume fp32_grad_bucket is already zeroed
            fp32_grad.add_(half_param.grad)
            # When _is_accumulation_sync_step is True there is no need to add the half gradients here: it's done in the allreduce hook.

        # TODO @thomasw21: Is it better to set to zero instead?
        half_param.grad = None

        # In case an optimizer decided to set the grad to None, we need to re-assign the previous buffer
        if name in self.parameters:
            fp32_param = self.parameters[name]["fp32"]
            if hasattr(self, "param_name_to_offsets"):
                if name not in self.param_name_to_offsets:
                    # When `name` isn't in `param_name_to_offsets` it means the slice is empty.
                    return
                start_offset, end_offset = self.param_name_to_offsets[name]
                grad = fp32_grad.view(-1)[start_offset:end_offset]
            else:
                grad = fp32_grad
            fp32_param.grad = grad

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronizations across
        data-parallel ranks.

        Note: once `no_sync` has been used, we assume we're in DDP mode, and the default of `self._is_accumulation_sync_step` is switched to True (see `FP32GradBucketManager.__post_init__`).
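
        Example:
            A sketch of accumulating over several micro-batches; `model`, `micro_batches`, and that `model(micro_batch)` returns the loss are assumptions:

                for i, micro_batch in enumerate(micro_batches):
                    if i < len(micro_batches) - 1:
                        with accumulator.no_sync():  # local accumulation only, no DP sync
                            accumulator.backward(model(micro_batch))
                    else:
                        accumulator.backward(model(micro_batch))  # last micro-batch: gradients get synced across DP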
        """
        old_is_accumulation_sync_step = self._is_accumulation_sync_step
        self._is_accumulation_sync_step = False
        try:
            yield
        finally:
            self._is_accumulation_sync_step = old_is_accumulation_sync_step

    @torch.inference_mode()
    def step(self):
        """Updates fp32 weights from fp32 grads.
        In case where OptimizerFromGradientAccumulator and gradient_accumulator_builder are using different parameters (e.g ZeRO).
        We need to update only the parameters that were updated by the optimizer.
        """
        for name in self.parameters.keys():
            fp32_param = self.parameters[name]["fp32"]
            half_param = self.parameters[name]["half"]
            # TODO @nouamane: should we use a fused kernel to copy?
            # Copy weights from full precision to half precision
            half_param.copy_(fp32_param)

    def zero_grad(self):
        # Full-precision gradients are reset to zero/None after the underlying `optimizer.step`, so there's no need to reset them here.
        for elt in self.fp32_grad_buffers.values():
            half_param = elt["half"]

            if half_param.grad is None:
                continue

            half_param.grad = None

        # In case self.parameters and self.fp32_grad_buffers are not the same (e.g. we want to accumulate all DP grads and only sync at the sync step)
        self._contiguous_fp32_grad_buffer.zero_()

    def get_parameter_for_optimizer(self, name: str) -> NanotronParameter:
        return self.parameters[name]["fp32"]

    def get_grad_buffer(self, name: str) -> torch.Tensor:
        """Returns the gradient of the parameter from the appropriate grad bucket."""
        return self.fp32_grad_buffers[name]["fp32_grad"]

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # We consider `fp32` parameters as a state of the gradient accumulator
        return {name: elt["fp32"] for name, elt in self.parameters.items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        assert set(state_dict.keys()) == set(self.parameters.keys())

        with torch.inference_mode():
            for name, elt in self.parameters.items():
                elt["fp32"].copy_(state_dict[name])


@dataclasses.dataclass
class FP32GradBucketManager:
    """Manages the fp32 gradient buckets.

    Attributes:
        dp_pg: The process group to allreduce gradients across.
        accumulator: The gradient accumulator which keeps the gradient buffers.
        bucket_id_to_fp32_grad_buckets_and_dependencies: A dictionary mapping bucket ids to:
            - fp32 grad bucket (torch.Tensor)
            - set of param ids that are in the bucket -> used to know when to delete the buffer
        param_id_to_bucket_id: A dictionary mapping param ids to bucket ids."""

    dp_pg: dist.ProcessGroup
    accumulator: FP32GradientAccumulator
    param_id_to_name: Dict[int, str]

    def __post_init__(self):
        self.accumulator._is_accumulation_sync_step = True


def get_fp32_accum_hook(
    reduce_scatter: bool,
    reduce_op: dist.ReduceOp = dist.ReduceOp.AVG,
) -> Callable:
    """Returns a DDP communication hook that performs gradient accumulation in fp32.

    Args:
        reduce_scatter: If True, `reduce_scatter` the fp32 grads across DP instead of `all_reduce`-ing them (requires the accumulator to have `param_name_to_offsets`, i.e. ZeRO-1).
        reduce_op: The reduction operation to perform.
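
    Example:
        A sketch of registering the hook on a DDP-wrapped model; `model`, `ddp_model`, `dp_pg`, and `accumulator` are assumed to already exist:

            param_id_to_name = {id(param): name for name, param in model.named_parameters()}  # same names as the accumulator uses
            ddp_model.register_comm_hook(
                state=FP32GradBucketManager(dp_pg=dp_pg, accumulator=accumulator, param_id_to_name=param_id_to_name),
                hook=get_fp32_accum_hook(reduce_scatter=False),
            )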
    """
    # s = torch.cuda.Stream()

    def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
        # nonlocal s
        # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation.
        # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details.
        dp_pg = state.dp_pg
        accumulator = state.accumulator
        param_id_to_name = state.param_id_to_name

        # Add new incoming gradient
        # with torch.cuda.stream(s):
        for param, grad in zip(bucket.parameters(), bucket.gradients()):
            name = param_id_to_name[id(param)]
            fp32_grad_buffer = accumulator.get_grad_buffer(name)
            fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer))

        # sync across dp
        if dp_pg.size() == 1:
            fut = torch.futures.Future()
            fut.set_result(bucket.buffer())
            return fut

        if reduce_scatter:
            assert hasattr(accumulator, "param_name_to_offsets")
            grad_buffer_tensor_list = [
                accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
            ]
            device = grad_buffer_tensor_list[0].device
            dtype = grad_buffer_tensor_list[0].dtype
            output_tensor_list = [
                grad_buffer[slice(*accumulator.param_name_to_offsets[param_id_to_name[id(param)]])]
                if param_id_to_name[id(param)] in accumulator.param_name_to_offsets
                else torch.empty(0, dtype=dtype, device=device)
                for grad_buffer, param in zip(grad_buffer_tensor_list, bucket.parameters())
            ]
            input_tensor_lists = [
                torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // dp_pg.size())
                for grad_buffer in grad_buffer_tensor_list
            ]
            dist.reduce_scatter_coalesced(
                output_tensor_list=output_tensor_list,
                input_tensor_lists=input_tensor_lists,
                op=reduce_op,
                group=dp_pg,
                async_op=True,
            )
        else:
            grad_buffer_tensor_list = [
                accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
            ]
            accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced(
                grad_buffer_tensor_list, group=dp_pg, async_op=True, op=reduce_op
            )
            # We don't wait on this future during the rest of the backward pass; it is awaited before the optimizer step.

        # with torch.cuda.stream(s):
        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        half_grad_bucket = bucket.buffer()
        fut.set_result(half_grad_bucket)
        return fut  # We don't care about the new half grad values, so we return the old ones

    return fp32_accum_hook