# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
7
8
A nn.Module wrapper to go with a Sharded Optimizer in order to handle targeted gradient
reduction automatically.
9
10
"""

11
from collections import deque
12
import contextlib
13
import functools
14
15
from itertools import chain
import logging
16
from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
17
18

import torch
19
from torch import nn
20
from torch.autograd import Variable
21
import torch.autograd.profiler as profiler
22
23
import torch.distributed as dist

24
from fairscale.nn.misc import GradBucket
25
from fairscale.optim import OSS
26
from fairscale.optim.utils import Workhandle, get_global_rank
27
28


29
30
31
32
def _trainable(param: torch.Tensor) -> bool:
    return param.requires_grad


class ShardedDataParallel(nn.Module):
    """Wrap the model, and reduce the gradients to the right rank during the backward pass.

    - the partition is given by the sharded optimizer
    - wrap the base model with a model which knows where to reduce each gradient
    - add an autograd function which calls the model grad dispatch on the way back

    Args:
        module (nn.Module):
            model to be wrapped
        sharded_optimizer (OSS, or list of OSS):
            the sharded optimizer(s) which will decide the gradient partitioning

    Keyword Args:
        process_group (group):
            torch.distributed group (default: group.WORLD)
        broadcast_buffers (bool):
            Whether to additionally broadcast model buffers in between ranks at the beginning of each forward pass.
            Same setting as in Pytorch DDP, this is in addition to the broadcast and reduction of the model parameters.
        sync_models_at_startup (bool):
            Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
            or the training restarts from a saved state
        reduce_buffer_size (int):
            The max size of the buffer used to batch the small parameter tensors, in number of elements
            (default 2 ** 23, i.e. 8M elements). Set to 0 to remove all bucketing, 1M to 8M is usually reasonable.
            This will impact the long term memory consumption, because these buckets correspond to parameters
            which will not be sharded. Bucketing is also deactivated when running on a single rank.
        auto_refresh_trainable (bool):
            (default: True) Check whether the parameters trainability (`requires_grad`) has changed and update both ShardedDDP
            and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
            a parameter is frozen or unfrozen.
        reduce_fp16 (bool):
            cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve performance
            for multi node jobs using PyTorch AMP. The effect is similar to DDP's fp16_compress_hook_ and will also save some memory.
            This is deactivated (with a warning) whenever `reduce_buffer_size` > 0, since it is not compatible with the
            reduction buckets.

    .. _fp16_compress_hook: https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html?highlight=fp16#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook

    .. warning::
        ShardedDDP implements gradient sharding, meaning that each rank only owns a unique shard of the model gradients
        after the backward pass, in order to save memory and some communication bandwidth.

    .. warning::
        As a consequence of sharding:
            * in case of gradient clipping, one has to use the `clip_grad_norm` exposed by
                the `optimizer state sharding wrapper <fairscale.optim.OSS>`

            * after loss.backward() (or equivalent) each rank will have `None` in place of some param.grad

            * Pytorch and Apex AMP implementations will hang when used in conjunction with `ShardedDDP`.
                One needs a `shard-aware grad scaler<ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler`,
                compatible with PytorchAMP.

    .. warning::
        If `auto_refresh_trainable` is set to `True` (this is the default) then any trainability change in the model graph will be handled
        automatically.
        If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
        for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
        ShardedDDP assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
    """

    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
        reduce_buffer_size: int = 2 ** 23,
        auto_refresh_trainable: bool = True,
        reduce_fp16: bool = False,
    ):
        super().__init__()

        # This field needs to be exposed to ensure interface parity with DDP
        self.module = module

        self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
        self._enable_broadcast_buffers = broadcast_buffers
        self._auto_refresh_trainable = auto_refresh_trainable
        self._reduce_fp16 = reduce_fp16
        if reduce_buffer_size > 0 and reduce_fp16:
            self._reduce_fp16 = False
            logging.warning(
                "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
            )

        # Handle a no_sync() context which prevents the gradient synchronization,
        # accumulate in place
        self._should_accumulate_grads = False
        self._accumulate_grads_flipped = False

        # Communication related attributes
        self._process_group = process_group if process_group is not None else dist.group.WORLD
        self._backend = dist.get_backend(self._process_group)
        world_size = dist.get_world_size(self._process_group)
        self._world_size_scaling = 1.0 / world_size  # > 0
        self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
        # Rank within the process group. The destination ranks returned by the optimizers are expressed in this
        # same (group local) space, see `_local_to_global_rank` below.
        self._rank = dist.get_rank(self._process_group)
        self._global_rank = get_global_rank(self._process_group, self._rank)
        self._local_to_global_rank = [get_global_rank(self._process_group, i) for i in range(world_size)]

        # Expose some of the PytorchDDP attributes, some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present, this is not handled
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1

        distinct_device_types = {p.device.type for p in self.module.parameters()}
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on "
            "the same type of devices, but input module parameters are located on {} different device types."
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the BW pass
        # several optimizers can be present each working on separate parameter set which is spread across multiple ranks

        # - we build a flat list of all the parameters involved globally,
        #   ordered by optimizer, then device, then owning rank (linear time flattening)
        self._all_params = [
            param
            for optim in self._sharded_optimizers
            for per_rank_params in optim._per_device_params.values()
            for rank_params in per_rank_params
            for param in rank_params
        ]
        self._trainable_params: List[torch.Tensor] = []
        self._grad_to_be_reduced: List[bool] = []
        self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
        self._reference_trainable_mask = list(map(_trainable, self._all_params))

        # - setup buckets and tensor views
        model_size = sum(p.numel() for p in self.module.parameters())
        self._buffer_max_size = min(reduce_buffer_size, model_size)

        if world_size == 1:
            self._buffer_max_size = 0
            logging.info("Training is not really distributed, single rank. Deactivating buckets")

        logging.info(
            "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
                self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
            )
        )
        self._use_buckets = self._buffer_max_size > 0

        self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
        self._should_bucket_grad: List[bool] = []
        self._bucket_list: List[GradBucket] = []

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._grad_hooks: List[Any] = []
        self._manual_reduce: List[Callable] = []

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()

        self._work_handles: Deque[Workhandle] = deque()
        self._bucket_flush_callback_set = False

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        """
        Module forward pass, handles any DDP-specific work in the background. Primes the
        backward pass for gradient reduction to the proper ranks.
        """

        with profiler.record_function("fairscale::sdp::forward"):
            # Deferred initialization, or change detection
            needs_setup = len(self._grad_hooks) == 0 and self.training

            if self._auto_refresh_trainable:
                needs_setup |= self._detect_train_change()

            if needs_setup:
                self.refresh_trainable()

            if self._enable_broadcast_buffers:
                # NCCL communications are on a different stream, needs to be blocking
                # for the subsequent FW to be correct
                self.sync_buffers(blocking=True)

            # Reset all the grad reduce and bucket state flags
            self._clear_counters()

            # Normal FW on the base model
            return self.module(*inputs, **kwargs)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> "ShardedDataParallel":
        """
        Moves and/or casts the parameters and buffers.

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        .. note::
            This method modifies the module in-place.

        .. warning::
            Device changes are not supported, and this will raise an exception. The issue in that case is not
            really ShardedDDP, but OSS which will not be aware of the device change, and whose buffers will be
            in a broken state.

        Arguments:
            device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
                `None` keeps the current device (dtype-only cast), an int is read as a device ordinal.
            dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
            non_blocking (bool): make it an asynchronous call.

        Returns:
            Module: self.
        """

        # Normalize ints/strings so that the comparison with the bucket keys (torch.device) is meaningful
        same_device = (
            device is None or len(self._buckets.keys()) == 0 or torch.device(device) in self._buckets.keys()
        )
        assert same_device, "Changing devices is not supported, because this would break OSSs state"
        assert (
            len(self._buckets.keys()) < 2
        ), "Several devices specified to begin with, incompatible with setting a single device here"

        # Buckets stay on their (unchanged) device, only the dtype may change
        for _device, device_buckets in self._buckets.items():
            for bucket in device_buckets.values():
                bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
        return self

    def refresh_trainable(self) -> None:
        """If the module trainability has changed, update all the assumptions"""

        # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
        if any(self._grad_to_be_reduced):
            logging.warning(
                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
            )

        with profiler.record_function("fairscale::sdp::refresh_trainable"):
            self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
            self._trainable_params.sort(key=lambda x: x.numel())

            # Keep the change-detection reference in sync, so that a manual refresh is not
            # detected (and redone) once more on the next forward pass
            self._reference_trainable_mask = list(map(_trainable, self._all_params))

            self._trainable_param_to_rank = {}
            for optim in self._sharded_optimizers:
                # OSS may need to change the communication pattern
                optim.refresh_trainable()

                # Update ShardedDDP given the new partitions
                # all the params on this device (inc all ranks)
                for device_per_rank_params in optim._per_device_params.values():
                    for device_params in device_per_rank_params:
                        for param in filter(lambda x: x.requires_grad, device_params):
                            self._trainable_param_to_rank[param] = optim._param_to_rank[param]

            self._setup_bucket_strategy()
            self._setup_backward_hooks()

    def reduce(self) -> None:
        """
        This does not *need* to be called, the gradient reduction is done automatically during the BW pass.
        Use this method to reduce the gradients manually
        """

        # Check that this is not a mistake, if there's nothing to reduce
        assert any(
            self._grad_to_be_reduced
        ), "No grads waiting to be reduced, maybe that this was called twice or there was no BW pass ?"

        # Trigger all the current BW hooks
        self._bucket_flush_callback_set = True  # no need to flush in the end, we own the callback execution
        for manual_reduce in self._manual_reduce:
            manual_reduce()

        # Make sure that all the futures are consumed
        self._consume_work_handles()

    @torch.no_grad()
    def sync_buffers(self, blocking: bool = False) -> None:
        """
        Sync all the param buffers in between ranks (including for instance batch norm statistics).

        Arguments:
            blocking (bool): wait for the operation to conclude.
        """

        with profiler.record_function("fairscale::sdp::sync_buffers"):
            work_handles = [
                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
                for buffer in self.module.buffers(recurse=True)
            ]

            if blocking and work_handles:
                if self._backend != dist.Backend.NCCL:
                    for work_handle in work_handles:
                        work_handle.wait()
                else:
                    # NCCL ops are serialized on the same stream, waiting on the last one is enough
                    work_handles[-1].wait()

    def zero_grad(self, set_to_none: bool = False) -> None:
        r"""Sets gradients of all model parameters to zero. See similar function
        under :class:`torch.optim.Optimizer` for more context.

        Arguments:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
                Grads handled by a reduction bucket are always zeroed (through their bucket), never set to None.
        """

        # `_should_bucket_grad` is aligned with `_trainable_params` (not `_all_params`), and is empty
        # when bucketing is disabled: resolve it to a set of bucketed params before walking all the params
        bucketed_param_ids = {
            id(param) for param, bucketed in zip(self._trainable_params, self._should_bucket_grad) if bucketed
        }

        for param in self._all_params:
            if set_to_none and id(param) not in bucketed_param_ids:
                param.grad = None
            elif param.grad is not None:
                param.grad.zero_()

        for bucket in self._bucket_list:
            bucket.zero()

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """A context manager to disable gradient synchronization.

        The previous accumulation state is restored on exit, even if the body raises.
        """

        old_should_accumulate_grads = self._should_accumulate_grads
        self._should_accumulate_grads = True
        try:
            yield
        finally:
            self._accumulate_grads_flipped = self._should_accumulate_grads != old_should_accumulate_grads
            self._should_accumulate_grads = old_should_accumulate_grads

    @torch.no_grad()
    def _clear_counters(self) -> None:
        """Reset all the grad reduce and call counters"""
        if self.training:
            self._grad_to_be_reduced = [True for _ in self._trainable_params]
        self._bucket_flush_callback_set = False

        if self._use_buckets:
            for bucket in self._bucket_list:
                bucket.reset_checked_in()

        if not self._should_accumulate_grads:
            self._accumulate_grads_flipped = False

    def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
        """
        Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
        or contribute to a bucket and reduce when the bucket is full.

        Either way a delayed action is necessary and is passed as a callback.

        Arguments:
            index (int): position of `param` in `self._trainable_params`
            param (torch.Tensor): the parameter whose gradient is reduced
            dst_rank (int): owning rank of `param`, within the process group (not a global rank)
        """

        if not self._use_buckets or not self._should_bucket_grad[index]:
            # Direct reduction
            @torch.no_grad()
            def reduce(*_: Any) -> None:
                # When accumulating (no_sync) or already reduced: skip, and do not alter the status flags
                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self._world_size_scaling)

                    if self._reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        # dst_rank is group-local, compare with the group-local rank
                        # (comparing with the global rank breaks for non-WORLD process groups)
                        if dst_rank != self._rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(dtype=param.dtype)

                    # Async reduce for this buffer, log the future
                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=param.grad.data,
                                dst=self._local_to_global_rank[dst_rank],
                                group=self._process_group,
                                async_op=True,
                            ),
                            callback=cleanup,
                        )
                    )

                    # Opportunistically try to empty the queue, free memory
                    self._try_consume_work_handle()

        else:

            @torch.no_grad()
            def reduce(*_: Any) -> None:
                # When accumulating (no_sync) or already reduced: skip, and do not alter the status flags
                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    bucket = self._buckets[param.device][dst_rank]
                    bucket.params_checked_in += 1

                    if bucket.all_checked_in:
                        assert bucket.buffer is not None

                        # Normalize the bucket in one go
                        bucket.buffer.mul_(self._world_size_scaling)

                        # Reduce the bucket
                        bucket.sent = True
                        self._work_handles.append(
                            Workhandle(
                                handle=dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=bucket.destination,
                                    group=self._process_group,
                                    async_op=True,
                                ),
                                callback=None,
                            )
                        )

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

        return reduce

    def _setup_backward_hooks(self) -> None:
        """
        Attach a reduce function to each grad-requiring parameter.
        This makes the gradient reduction automatic whenever there's a backward pass
        """

        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
            # Detach possible pre-existing hooks
            while len(self._grad_hooks) > 0:
                self._grad_hooks.pop().remove()

            # Go through the parameters, attach the hook
            self._grad_accs = []
            self._manual_reduce = []
            if not self.training:
                return

            for index, param in enumerate(self._trainable_params):
                if param.grad is not None and param.grad.requires_grad:
                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")

                p_tmp = param.expand_as(param)

                # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
                # We're interested in the tensors which will be tracked by Autograd
                # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
                # these do not need to be sync'ed
                if p_tmp.grad_fn is not None:
                    # Register the hook to the next function in line,
                    # so that the hook is fired when this grad has properly been computed
                    # (by default the hook with Pytorch is a pre-grad, not a post-grad)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    dst_rank = self._trainable_param_to_rank[param]

                    reduce_function = self._get_reduce_fn(index, param, dst_rank)

                    self._grad_hooks.append(grad_acc.register_hook(reduce_function))
                    self._grad_accs.append(grad_acc)  # keep this hook in scope
                    self._manual_reduce.append(reduce_function)

    @torch.no_grad()
    def _sync_params_and_buffers(self) -> None:
        """
        Sync the complete model states in between the ranks
        """

        work_handles = [
            dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
            for t in self.module.state_dict().values()
        ]

        # gloo does not guarantee inlining like NCCL, wait for all requests
        if self._backend != dist.Backend.NCCL:
            for work_handle in work_handles:
                work_handle.wait()
        elif work_handles:
            work_handles[-1].wait()

    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
        """
        Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
        Adapted from ``torch.nn.distributed.DistributedDataParallel``.
        """
        for layer in module.modules():
            if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                # device_id logic has not been handled, assume single-process single-device
                # SyncBatchNorm only supports DDP with single-process single-device anyway'
                # This function is removed from pytorch since 1.9.
                layer._specify_ddp_gpu_num(1)  # type: ignore

    def _setup_bucket_strategy(self) -> None:
        """Devise a bucketing strategy on a per-rank ownership level.
        These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
        This method can be slow for big models, but it is not typically called often (not for every forward for instance)
        """

        with profiler.record_function("fairscale::sdp::setup_buckets"):
            if not self._use_buckets:
                return

            # Devise the bucketing strategy. Parameters are already sorted, in that:
            # - these are only the trainable parameters, so they should produce grads
            # - they are sorted by increasing size
            self._buckets = {}
            self._should_bucket_grad = [False for _ in self._trainable_params]

            for i, param in enumerate(self._trainable_params):
                device = param.device
                dst_rank = self._trainable_param_to_rank[param]

                if device not in self._buckets.keys():
                    self._buckets[device] = {}

                if dst_rank not in self._buckets[device].keys():
                    self._buckets[device][dst_rank] = GradBucket(
                        self._buffer_max_size,
                        dtype=param.dtype,
                        device=device,
                        destination=self._local_to_global_rank[dst_rank],
                    )

                # Criteria to decide whether this parameter is to be bucketed or not:
                # - enough room in the bucket
                if self._buckets[device][dst_rank].can_add_grad_view(param):
                    self._buckets[device][dst_rank].add_grad(param)
                    self._should_bucket_grad[i] = True

            self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

            # Resize the buckets to remove lost space in the end
            for bucket in self._bucket_list:
                bucket.shrink()

    def _consume_work_handles(self) -> None:
        """Consume all the pending reduce futures, blocking until each one completes, and run their callbacks.
        We start from the first/older ones, since they are the most likely to be ready and non-blocking
        """

        while len(self._work_handles) > 0:
            work_handle = self._work_handles.popleft()
            work_handle.handle.wait()
            if work_handle.callback is not None:
                work_handle.callback()

    def _try_consume_work_handle(self) -> None:
        """Try to consume the oldest futures. This is non blocking, stops at the first one which is not ready"""
        while len(self._work_handles) > 0 and self._work_handles[0].handle.is_completed():
            work_handle = self._work_handles.popleft()
            if work_handle.callback is not None:
                work_handle.callback()

    def _flush_reduce_calls(self) -> None:
        """Send the buckets which have not been sent yet, then wait for all the pending reductions.
        Queued as an autograd engine callback, so that it runs at the end of the backward pass.
        """
        for bucket in self._bucket_list:
            if not bucket.sent:
                assert bucket.buffer is not None

                # Normalize the bucket in one go
                bucket.buffer.mul_(self._world_size_scaling)

                # Reduce the bucket
                self._work_handles.append(
                    Workhandle(
                        handle=dist.reduce(
                            tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
                        ),
                        callback=None,
                    )
                )
                bucket.sent = True

        self._consume_work_handles()

    def _detect_train_change(self) -> bool:
        """Return True if the trainable parameters changed since the last check, or if the model is in eval
        mode while backward hooks are still attached. Updates the reference mask when a change is found."""
        with profiler.record_function("fairscale::sdp::detect_train_changes"):
            # Optionally check whether the trainable parameters have changed
            trainable_mask = list(map(_trainable, self._all_params))

            # - one or more parameters trainability changed
            trainability_changed = trainable_mask != self._reference_trainable_mask

            # - the whole model is not trainable but we still have grad hooks
            trainability_changed |= not self.training and len(self._grad_hooks) > 0

            if trainability_changed:
                logging.warning(
                    "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
                )
                self._reference_trainable_mask = trainable_mask

        return trainability_changed