# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
An nn.Module wrapper to go with a Sharded Optimizer, in order to handle targeted
gradient reduction automatically.
"""

from collections import deque
import contextlib
import functools
from itertools import chain
import logging
from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union

import torch
from torch import nn
from torch.autograd import Variable
import torch.distributed as dist

from fairscale.nn.misc import GradBucket
from fairscale.optim import OSS
from fairscale.optim.utils import Workhandle, get_global_rank


def _trainable(param: torch.Tensor) -> bool:
    return param.requires_grad


class ShardedDataParallel(nn.Module):
    """ Wrap the model, and reduce the gradients to the right rank during the backward pass.

    - the partition is given by the sharded optimizer
    - wrap the base model with a model which knows where to reduce each gradient
    - add an autograd function which calls the model grad dispatch on the way back

    Args:
        module (nn.Module):
            model to be wrapped
        sharded_optimizer (OSS, or list of OSS):
            the sharded optimizer(s) which will decide the gradient partitioning

    Keyword Args:
        process_group (group):
            torch.distributed group (default: group.WORLD)
        broadcast_buffers (bool):
            Whether to additionally broadcast model buffers in between ranks at the beginning of each forward pass.
            Same setting as in Pytorch DDP, this is in addition to the broadcast and reduction of the model parameters.
        sync_models_at_startup (bool):
            Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
            or the training restarts from a saved state.
        reduce_buffer_size (int):
            The max size of the buffer used to batch the small parameter tensors, in number of elements
            (default: ``2 ** 23``; set to 0 to disable bucketing).
            This will impact the long term memory consumption, because these buckets correspond to parameters which will not be sharded.
            1M to 8M is usually reasonable.
        auto_refresh_trainable (bool):
            (default: True) Check whether the parameters' trainability (``requires_grad``) has changed and update both ShardedDDP
            and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
            a parameter is frozen or unfrozen.
        reduce_fp16 (bool):
            Cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve performance
            for multi-node jobs using PyTorch AMP. The effect is similar to DDP's fp16_compress_hook_ and will also save some memory.

    .. _fp16_compress_hook: https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html?highlight=fp16#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook

    .. warning::
        ShardedDDP implements gradient sharding, meaning that each rank only owns a unique shard of the model gradients
        after the backward pass, in order to save memory and some communication bandwidth.

    .. warning::
        As a consequence of sharding:
            * in case of gradient clipping, one has to use the `clip_grad_norm` exposed by
                the `optimizer state sharding wrapper <fairscale.optim.OSS>`

            * after loss.backward() (or equivalent) each rank will have `None` in place of some param.grad

            * Pytorch and Apex AMP implementations will hang when used in conjunction with `ShardedDDP`.
                One needs a `shard-aware grad scaler <ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler`,
                compatible with PyTorch AMP.

    .. warning::
        If `auto_refresh_trainable` is set to `True` (this is the default) then any trainability change in the model graph will be handled
        automatically.
        If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
        for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
        ShardedDDP assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
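
    Example:
        A minimal usage sketch; ``model``, ``loader`` and ``loss_fn`` are assumed to exist::

            optimizer = OSS(params=model.parameters(), lr=1e-4)
            model = ShardedDataParallel(model, optimizer)

            for batch, target in loader:
                optimizer.zero_grad()
                loss = loss_fn(model(batch), target)
                loss.backward()  # gradients are reduced to their owner ranks here
                optimizer.step()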

    """

    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
        reduce_buffer_size: int = 2 ** 23,
        auto_refresh_trainable: bool = True,
        reduce_fp16: bool = False,
    ):
        super().__init__()

        self.module = module
        self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
        self.enable_broadcast_buffers = broadcast_buffers
        self.auto_refresh_trainable = auto_refresh_trainable
        self.reduce_fp16 = reduce_fp16
        if reduce_buffer_size > 0 and reduce_fp16:
            self.reduce_fp16 = False
            logging.warning(
                "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
            )

        # Handle a no_sync() context which prevents the gradient synchronization,
        # accumulate in place
        self.should_accumulate_grads = False
        self.accumulate_grads_flipped = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.backend = dist.get_backend(self.process_group)
        self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
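        # (grads are multiplied by this factor before being reduced, so the default SUM reduction yields an average)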
        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
        self.rank = dist.get_rank(self.process_group)
        self.global_rank = get_global_rank(self.process_group, self.rank)
        self._local_to_global_rank = [
            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
        ]

        # Expose some of the PytorchDDP attributes, some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present, this is not handled
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1
        self.device = list(devices)[0]

        distinct_device_types = {p.device.type for p in self.module.parameters()}
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on "
            "the same type of devices, but input module parameters are located on {} different device types."
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the BW pass
        # several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks

        # - we build an iterator which goes through all the parameters involved globally
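        #   (`_per_device_params` maps each device to per-rank lists of params; the nested
        #   sums flatten across ranks and devices, and chain() joins the optimizers)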
        self._all_params = list(
            chain(
                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
            )
        )
        self._trainable_params: List[torch.Tensor] = []
        self._grad_to_be_reduced: List[bool] = []
        self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
        self._reference_trainable_mask = list(map(_trainable, self._all_params))

        # - setup buckets and tensor views
        model_size = sum([p.numel() for p in self.module.parameters()])
        self.buffer_max_size = min(reduce_buffer_size, model_size)

        if dist.get_world_size(self.process_group) == 1:
            self.buffer_max_size = 0
            logging.info("Training is not really distributed, single rank. Deactivating buckets")

        logging.info(
            "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
                self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
            )
        )
        self.use_buckets = self.buffer_max_size > 0

        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
        self._should_bucket_grad: List[bool] = []
        self._bucket_list: List[GradBucket] = []

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._grad_hooks: List[Any] = []
        self._manual_reduce: List[Callable] = []

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()

        self._work_handles: Deque[Workhandle] = deque()
        self._bucket_flush_callback_set = False

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        """
        Module forward pass, handles any DDP-specific work in the background. Primes the
        backward pass for gradient reduction to the proper ranks.
        """

        # Deferred initialization, or change detection
        needs_setup = len(self._grad_hooks) == 0

        if self.auto_refresh_trainable:
            # Optionally check whether the trainable parameters have changed
            trainable_mask = list(map(_trainable, self._all_params))
            if trainable_mask != self._reference_trainable_mask:
                logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
                needs_setup = True
                self._reference_trainable_mask = trainable_mask

        if needs_setup:
            self.refresh_trainable()

        if self.enable_broadcast_buffers:
            # NCCL communications are on a different stream, needs to be blocking
            # for the subsequent FW to be correct
            self.sync_buffers(blocking=True)

        # Reset all the grad reduce and bucket state flags
        self._clear_counters()

        # Normal FW on the base model
        return self.module(*inputs, **kwargs)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> "ShardedDataParallel":
        """
        Moves and/or casts the parameters and buffers.

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        .. note::
            This method modifies the module in-place.

        .. warning::
            Device changes are not supported, and this will raise an exception. The issue in that case is not
            really ShardedDDP, but OSS, which would not be aware of the device change and whose buffers would be
            left in a broken state.

        Arguments:
            device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
            dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
            non_blocking (bool): make it an asynchronous call.

        Returns:
            Module: self.
        """

        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSS's state"
        assert (
            len(self.buckets.keys()) == 1
        ), "Several devices specified to begin with, incompatible with setting a single device here"

        for _device in self.buckets.keys():
            for bucket in self.buckets[_device].values():
                bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)

    def refresh_trainable(self) -> None:
        """ If the module trainability has changed, update all the assumptions """

        # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
        assert not functools.reduce(
            lambda x, y: x or y, self._grad_to_be_reduced, False
        ), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)

        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
        self._trainable_params.sort(key=lambda x: x.numel())
        self._grad_to_be_reduced = [True for _ in self._trainable_params]

        self._trainable_param_to_rank = {}
        for optim in self.sharded_optimizers:
            # OSS may need to change the communication pattern
            optim.refresh_trainable()

            # Update ShardedDDP given the new partitions
            for (
                device_per_rank_params
            ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
                for device_params in device_per_rank_params:
                    for param in filter(lambda x: x.requires_grad, device_params):
                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]

        self._setup_bucket_strategy()
        self._setup_backward_hooks()

    def reduce(self) -> None:
        """
        This does not *need* to be called, the gradient reduction is done automatically during the BW pass.
        Use this method to reduce the gradients manually.
        """

        # Check that this is not a mistake, if there's nothing to reduce
        assert functools.reduce(
            lambda x, y: x or y, self._grad_to_be_reduced, False
        ), "No grads waiting to be reduced, maybe that this was called twice or there was no BW pass ?"

        # Trigger all the current BW hooks
        self._bucket_flush_callback_set = True  # no need to flush in the end, we own the callback execution
        _ = list(map(lambda x: x(), self._manual_reduce))

        # Make sure that all the futures are consumed
        self._consume_work_handles()

    @torch.no_grad()
    def sync_buffers(self, blocking: bool = False) -> None:
        """
        Sync all the param buffers in between ranks (including for instance batch norm statistics).

        Arguments:
            blocking (bool): wait for the operation to conclude.
        """

        work_handles = []

        for buffer in self.module.buffers(recurse=True):
            work_handles.append(
                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
            )

        if blocking and work_handles:
            if self.backend != dist.Backend.NCCL:
                _ = list(filter(lambda x: x.wait(), work_handles))
            else:
                work_handles[-1].wait()

    def zero_grad(self, set_to_none: bool = False) -> None:
        r"""Sets gradients of all model parameters to zero. See similar function
        under :class:`torch.optim.Optimizer` for more context.

        Arguments:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
        """

        for index, trainable_param in enumerate(self._all_params):
            if set_to_none and not self._should_bucket_grad[index]:
                trainable_param.grad = None
            elif trainable_param.grad is not None:
                trainable_param.grad.zero_()

        for bucket in self._bucket_list:
            bucket.zero()

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """A context manager to disable gradient synchronization."""
        old_should_accumulate_grads = self.should_accumulate_grads
        self.should_accumulate_grads = True
        yield
        self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
        self.should_accumulate_grads = old_should_accumulate_grads

    @torch.no_grad()
    def _clear_counters(self) -> None:
        """Reset all the grad reduce and call counters"""
        self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
        self._bucket_flush_callback_set = False

        if self.use_buckets:
            for bucket in self._bucket_list:
                bucket.reset_checked_in()

        if not self.should_accumulate_grads:
            self.accumulate_grads_flipped = False

    def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
        """
        Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
        or contribute to a bucket and reduce when the bucket is full.

        Either way a delayed action is necessary and is passed as a callback.
        """

        if not self.use_buckets or not self._should_bucket_grad[index]:
            # Direct reduction
            @torch.no_grad()
            def reduce(*_: Any) -> None:
                # Skip gradient reduction, do not alter status flags
                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self.world_size_scaling)

                    if self.reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Once the reduction completes, non-owner ranks can free the grad, and the owner casts it back to the param dtype
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(dtype=param.dtype)

                    # Async reduce for this buffer, log the future
                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=param.grad.data,
                                dst=self._local_to_global_rank[dst_rank],
                                group=self.process_group,
                                async_op=True,
                            ),
                            callback=cleanup,
                        )
                    )

                    # Opportunistically try to empty the queue, free memory
                    self._try_consume_work_handle()

        else:

            @torch.no_grad()
            def reduce(*_: Any) -> None:
                # Skip gradient reduction, do not alter status flags

                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    bucket = self.buckets[param.device][dst_rank]
                    bucket.params_checked_in += 1

                    if bucket.all_checked_in:
                        assert bucket.buffer is not None

                        # Normalize the bucket in one go
                        bucket.buffer.mul_(self.world_size_scaling)

                        # Reduce the bucket
                        bucket.sent = True
                        self._work_handles.append(
                            Workhandle(
                                handle=dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=bucket.destination,
                                    group=self.process_group,
                                    async_op=True,
                                ),
                                callback=None,
                            )
                        )

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

        return reduce

    def _setup_backward_hooks(self) -> None:
        """
        Attach a reduce function to each grad-requiring parameter.
        This makes the gradient reduction automatic whenever there's a backward pass
        """

        # Detach possible pre-existing hooks
        while len(self._grad_hooks) > 0:
            self._grad_hooks.pop().remove()

        # Go through the parameters, attach the hook
        self._grad_accs = []
        self._manual_reduce = []
        for index, param in enumerate(self._trainable_params):
            if param.grad is not None and param.grad.requires_grad:
                raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")

            p_tmp = param.expand_as(param)

            # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
            # We're interested in the tensors which will be tracked by Autograd
            # Some tensors can have gradients independent of the inputs (e.g. a pooling layer),
            # these do not need to be sync'ed
            if p_tmp.grad_fn is not None:
                # Register the hook to the next function in line,
                # so that the hook is fired when this grad has properly been computed
                # (by default the hook with Pytorch is a pre-grad, not a post-grad)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                dst_rank = self._trainable_param_to_rank[param]

                reduce_function = self._get_reduce_fn(index, param, dst_rank)

                self._grad_hooks.append(grad_acc.register_hook(reduce_function))
                self._grad_accs.append(grad_acc)  # keep this hook in scope
                self._manual_reduce.append(reduce_function)

    @torch.no_grad()
    def _sync_params_and_buffers(self) -> None:
        """
        Sync the complete model states in between the ranks
        """

        work_handles = []

        for t in self.module.state_dict().values():
            work_handles.append(
                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
            )

        # gloo does not guarantee inlining like NCCL, wait for all requests
        if self.backend != dist.Backend.NCCL:
            _ = list(filter(lambda x: x.wait(), work_handles))
        elif work_handles:
            work_handles[-1].wait()

    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
        """
        Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
        Adapted from ``torch.nn.parallel.DistributedDataParallel``.
        """
        for layer in module.modules():
            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                # device_id logic has not been handled, assume single-process single-device
                # SyncBatchNorm only supports DDP with single-process single-device anyway
                layer._specify_ddp_gpu_num(1)  # type: ignore

    def _setup_bucket_strategy(self) -> None:
        """Devise a bucketing strategy on a per-rank ownership level.
        These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
        This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
        """

        if not self.use_buckets:
            return

        # Devise the bucketing strategy. Parameters are already sorted, in that:
        # - these are only the trainable parameters, so they should produce grads
        # - they are sorted by increasing size
        self.buckets = {}
        self._should_bucket_grad = [False for _ in self._trainable_params]

        for i, param in enumerate(self._trainable_params):
            device = param.device
            dst_rank = self._trainable_param_to_rank[param]

            if param.device not in self.buckets.keys():
                self.buckets[param.device] = {}

            if dst_rank not in self.buckets[param.device].keys():
                self.buckets[param.device][dst_rank] = GradBucket(
                    self.buffer_max_size,
                    dtype=param.dtype,
                    device=param.device,
                    destination=self._local_to_global_rank[dst_rank],
                )

            # Criteria to decide whether this parameter is to be bucketed or not:
            # - enough room in the bucket
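            #   (GradBucket exposes each added grad as a view into the bucket's flat buffer,
            #   so a full bucket can be reduced with a single collective call)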
            if self.buckets[device][dst_rank].can_add_grad_view(param):
                self.buckets[device][dst_rank].add_grad(param)
                self._should_bucket_grad[i] = True

        self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))

        # Resize the buckets to remove wasted space at the end
        for bucket in self._bucket_list:
            bucket.shrink()

    def _consume_work_handles(self) -> None:
        """Consume all the futures which are tied to this optimizer's buckets.
        We start from the first/oldest ones, since they are the most likely to be ready and non-blocking.
        """

        while len(self._work_handles) > 0:
            work_handle = self._work_handles.popleft()
            work_handle.handle.wait()
            if work_handle.callback is not None:
                work_handle.callback()

    def _try_consume_work_handle(self) -> None:
        """Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
        while len(self._work_handles) > 0 and self._work_handles[0].handle.is_completed():
            work_handle = self._work_handles.popleft()
            if work_handle.callback is not None:
                work_handle.callback()

    def _flush_reduce_calls(self) -> None:
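        """Reduce any bucket which has not been sent yet, then consume the pending work handles."""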
        for bucket in self._bucket_list:
            if not bucket.sent:
                assert bucket.buffer is not None

                # Normalize the bucket in one go
                bucket.buffer.mul_(self.world_size_scaling)

                # Reduce the bucket
                self._work_handles.append(
                    Workhandle(
                        handle=dist.reduce(
                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
                        ),
                        callback=None,
                    )
                )
                bucket.sent = True

        self._consume_work_handles()