oss.py 25.7 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

6
from collections import OrderedDict
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
7
import copy
8
import itertools
9
from itertools import chain
10
import logging
11
12
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
13

14
import torch
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
15
import torch.distributed as dist
16
from torch.nn import Parameter
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
17
18
from torch.optim import SGD, Optimizer

19
from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
20

21
22
__all__ = ["OSS"]

23
if TYPE_CHECKING:  # pragma: no cover
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
24
25
26
27
28
29
30
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class OSS(Optimizer):
    """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
    optimizer and shards its state as described by ZeRO_.
    ::

        opt = OSS(params, optim=torch.optim.Adam, lr=0.01)

    .. _ZeRO: https://arxiv.org/abs/1910.02054

    We use a greedy algorithm to pack a number of parameters
    at each rank. Each parameter belongs to a single rank and
    is not divided among ranks.

    After each rank has completed its parameter update, it broadcasts
    the new version of its parameters to all the other ranks, to synchronize
    the parameters for the next round of forward/backward computation.

    Args:
        params (list of tensors):
            parameters to be optimized
    Keyword Args:
        optim (torch.optim.Optimizer):
            optimizer class to shard (default: SGD)
        group (group):
            torch.distributed group (default: group.WORLD)
        broadcast_buffer_size (int):
            number of elements of the buffer used to batch the small parameter tensors (default: 2 ** 17).
    """

58
    #: The optimizer used for a given shard
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
59
    optim: Optimizer
60

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
61
62
    in_super_constructor: bool

63
64
65
66
67
68
69
70
    def __init__(
        self,
        params: _params_t,
        optim: Type[Optimizer] = SGD,
        group: Optional[Any] = None,
        broadcast_buffer_size: int = 2 ** 17,
        **default: Any,
    ):
        """Build the sharded optimizer.

        Arguments:
            params: parameters (or param groups) to optimize, all of them are kept in ``self.param_groups``
            optim: optimizer class instantiated on this rank's shard of the parameters
            group: torch.distributed group, ``dist.group.WORLD`` is used when ``None``
            broadcast_buffer_size: number of elements of each of the broadcast buffers
            default: extra keyword arguments, used both as the defaults of this optimizer and as the
                keyword arguments of the wrapped optimizer constructor

        Raises:
            ValueError: if there is not a single parameter to optimize, in which case no device can be picked
        """
        # Hold all the model params in the root .param_groups
        # (the flag keeps add_param_group from re-partitioning while the base constructor runs)
        self.in_super_constructor = True
        super().__init__(params, default)
        self.in_super_constructor = False

        # Partition information. lazy evaluation, computed when requested
        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
        self._param_rank: Dict[torch.Tensor, int] = {}
        self._partition_parameters: List[List[dict]] = []

        # Build the wrapped optimizer, responsible for a shard of the params
        self.group = group if group is not None else dist.group.WORLD
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group)
        self.global_rank = self.get_global_rank(self.group, self.rank)

        self.optim = optim(self.partition_parameters()[self.rank], **default)

        # - Sync local and global param_groups keys
        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
            for key, value in local_group.items():
                if key != "params":
                    global_group[key] = value

        #  Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []

        # Current default device is set by the parameters allocated to this rank.
        # A rank may own no parameter at all (more ranks than parameters, for instance),
        # in that case fall back to the first device which holds a parameter.
        local_params = [p for param_group in self.partition_parameters()[self.rank] for p in param_group["params"]]
        known_devices = list(self.per_device_params.keys())
        if local_params:
            self._device = local_params[0].device
        elif known_devices:
            self._device = known_devices[0]
        else:
            raise ValueError("OSS needs at least one parameter to optimize")

        self.buckets: Dict[torch.device, List[Bucket]] = {}
        for device, per_device in self.per_device_params.items():
            # Allocate one buffer per rank and per device to group the small parameters.
            # A device is only registered when at least one parameter lives on it, but not necessarily
            # a parameter of rank 0, so the dtype is taken from the first parameter found on any rank.
            buffer_dtype = next(chain.from_iterable(per_device)).dtype
            self.buckets[device] = [
                Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=buffer_dtype, device=device))
                for _ in range(len(per_device))
            ]

        self.should_bucket_param: Dict[torch.Tensor, bool] = {}
        self.work_handles: List[Workhandle] = []
        self._max_work_handles = -1
        self._setup_bucket_strategy()
111

112
    # Partition helpers
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
113
    def partition_parameters(self) -> List[List[dict]]:
114
        """Partitions parameters across distributed data parallel ranks.
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
115
116
117
118
119
120

        Returns a list of param_groups (which is a list of dict) where each
        element of the list contains the param_groups for a rank. Element 0
        corresponds to rank 0, etc. We need all the ranks for the broadcast
        inside step().
        """
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
        if len(self._partition_parameters) == 0:
            self._partition_parameters = [list() for _ in range(self.world_size)]
            sizes = [0] * self.world_size
            for param_group in self.param_groups:
                param_lists: List[List] = [list() for _ in range(self.world_size)]
                for param in param_group["params"]:
                    # Add this param to rank with smallest size.
                    rank = sizes.index(min(sizes))
                    param_lists[rank].append(param)
                    sizes[rank] += param.numel()

                for rank, params in enumerate(param_lists):
                    param_group_rank = copy.copy(param_group)
                    param_group_rank["params"] = params
                    self._partition_parameters[rank].append(param_group_rank)

        return self._partition_parameters

    @property
140
141
    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
        """Sorted list of all the params, first per device then per rank.
142

143
144
        Within a list params are sorted per number of elements to allow for an easy bucketing.
        """
145
        if len(self._per_device_params) == 0:
146
147
148
            # Go through all params, log them per device
            # The ordering is important here, needs to be the same on all ranks
            # So that ulterior broadcast calls are matching
149
150
151
            for param_group in self.param_groups:
                for param in param_group["params"]:
                    device = param.device
152
153
154
155
156
                    if self._per_device_params.get(device) is None:
                        self._per_device_params[device] = [[] for _ in range(self.world_size)]
                    self._per_device_params[device][self.param_to_rank[param]] += [param]

            # Sort param_lists by size
157
158
159
            for device in self._per_device_params.keys():
                for rank_params in self._per_device_params[device]:
                    rank_params.sort(key=lambda x: x.numel())
160
161
162
163
164

        return self._per_device_params

    @property
    def param_to_rank(self) -> Dict[torch.Tensor, int]:
165
        """param to data parallel rank"""
166
167
168
169
170
        if len(self._param_rank) == 0:
            for rank, param_groups in enumerate(self.partition_parameters()):
                for param_group in param_groups:
                    for param in param_group["params"]:
                        self._param_rank[param] = rank
171
172
173

            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))

174
        return self._param_rank
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
175

176
177
178
    # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
    # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
    def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
        """Run one optimization step on this rank's shard, then share the updated parameters.

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Returns:
            Whatever the wrapped optimizer ``step`` returns.

        .. note: Any extra keyword argument is passed to the wrapped optimizer as-is"""

        # A scheduler may have changed the exposed param_groups (lr, ...), push that to the wrapped optimizer
        self._sync_param_groups()

        # Step on this shard only. `closure` is only forwarded when it was given,
        # so that a wrapped `step` which does not accept this keyword keeps working
        step_kwargs = dict(kwargs)
        if closure is not None:
            step_kwargs["closure"] = closure
        loss = self.optim.step(**step_kwargs)  # type: ignore

        # Depending on the DDP engine used, gradients specific to other ranks may still be loaded
        self._free_other_grads()

        # Every rank broadcasts the shard it just updated
        self._broadcast_params()

        # Expose whatever the wrapped optimizer may have changed in its own param_groups
        self._sync_param_groups(local_to_global=True)

        return loss

207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Arguments:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
            under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
            in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters

        .. warning: This needs to be called on all ranks, since synchronization primitives will be used

        .. warning: Model paralelism -groups other than world- are not yet supported
        """

        if self.group != dist.group.WORLD:
            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")

        # Compute the max norm for this shards's worth of gradients
        max_norm = float(max_norm)
        norm_type = float(norm_type)

        # Filter out the grad-less params, concatenate params from all devices
        local_params = itertools.chain(
            *[
                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
                for device_params in self.per_device_params.values()
            ]
        )

        # Compute the norm on this grad set,
        # then sync all the norms from all ranks
        if norm_type == inf:
            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
        else:
            local_norm = torch.norm(
                input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type).to(self._device) for p in local_params]),  # type: ignore
                p=norm_type,
            )

            # local norm result can be accumulated with the remote ones if put to the right power
            # n_i = sum_rank(a^p)^1/p
            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
            total_norm = local_norm ** norm_type
            dist.all_reduce(total_norm, group=self.group)
            total_norm = total_norm ** (1.0 / norm_type)

        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)

        if clip_coef < 1:
            for device, device_params in self.per_device_params.items():
                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore

        return total_norm

    # State dict interfaces
271
    def local_state_dict(self) -> dict:
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
272
273
274
275
276
277
278
279
280
281
        """Gets this rank's state_dict.

        Returns:
            The state of the optimizer as a :class:`dict`.
            It contains two entries:

            * state - a dict holding current optimization state. Its content
                differs between optimizer classes.
            * param_groups - a dict containing all parameter groups
        """
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
282
283
        return self.optim.state_dict()

284
    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """Gather the state shards of all the ranks on `recipient_rank`, one entry per rank.

        Arguments:
            recipient_rank (int): the rank which stores the consolidated state

        .. warning: This needs to be called on all replicas"""

        # Make sure that the wrapped optimizer knows about the latest lr and other attributes
        self._sync_param_groups()

        if self.rank != recipient_rank:
            # Acknowledge broadcasts, and send this rank's shard when needed
            self._broadcast_state_dict()
            return

        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank
        logging.debug("Pulling the sharded optimizer state from all replicas")
        self._all_states = self._collect_sharded_states()

    def state_dict(self) -> Dict[str, Any]:
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
302
303
304
305
        """Return the last known global optimizer state, which consist of a list of the shards.

        .. warning:
            If the state has not been consolidated, this returns a shard's worth, not the global state.
306

Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
307
308
309
        .. warning:
            Returning the global state is limited to the replica which was responsible for the consolidation.
            The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
310
311
        """

312
313
314
315
316
317
        if len(self._all_states) == 0:
            logging.warning("Optimizer state has not been consolidated. Returning the local state")
            logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
            state_dict = self.local_state_dict()
            state_dict["local_state_dict"] = True
            return state_dict
318

319
320
321
322
323
324
325
326
327
328
329
        # Flatten the param_groups, save the partition which logs the rank <> shard correspondence
        partition: List[Tuple[int, int]] = []
        param_groups: List[Dict[Any, Any]] = []

        start = 0
        for i, s in enumerate(self._all_states):
            param_groups.extend(s["param_groups"])
            end = start + len(s["param_groups"])
            partition.append((start, end))
            start = end

330
331
        return {
            "state": [s["state"] for s in self._all_states],
332
333
            "param_groups": param_groups,
            "partition": partition,
334
            "local_state_dict": False,
335
        }
336

337
338
339
340
341
342
343
344
345
346
347
348
    @staticmethod
    def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
        """Returns the local_state_dict for a given rank.

        Arguments:
            rank (int): rank to get local_state_dict for
            state_dict (dict): global state_dict
        """
        # Get this optimizer's param_groups shard
        param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
        return {"state": state_dict["state"][rank], "param_groups": param_groups}

349
    def load_local_state_dict(self, state_dict: dict) -> None:
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
350
351
352
353
        """Loads this rank's state_dict.

        .. warning: This is not meant to load the global state dict.
        """
354

355
        self.optim.load_state_dict(state_dict)
356

357
358
359
360
361
362
363
364
365
366
367
368
369
        # Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
        # Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
        groups = self.optim.param_groups
        saved_groups = state_dict["param_groups"]
        id_map = {
            old_id: p
            for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
        }
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)

370
        # Restore the global param_groups (the params themselves are already correct)
371
        for global_group, local_group in zip(self.param_groups, groups):
372
373
374
            for k, v in local_group.items():
                if k != "params":
                    global_group[k] = v
375

376
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
377
378
379
380
381
382
        """Restore the global parameter groups as well as the shard.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`
        """
383

384
385
386
387
388
        # Check whether we got a local or global dict
        if state_dict["local_state_dict"]:
            self.load_local_state_dict(state_dict)
        else:
            # Dispatch this rank's state dictionary to the wrapped shard optimizer
389
            self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict))
390

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
391
    def add_param_group(self, param_group: dict) -> None:
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
392
393
394
395
396
397
398
399
400
401
402
403
        """Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options

        .. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
        """

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
404
405
        super().add_param_group(param_group)
        if not self.in_super_constructor:
406
407
408
409
            # Force a re-partitioning
            self._partition_parameters.clear()
            self._per_device_params.clear()
            self._param_rank.clear()
410

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
411
412
413
            param_groups = self.partition_parameters()[self.rank]
            if len(param_groups) == len(self.optim.param_groups) + 1:
                self.optim.add_param_group(param_groups[-1])
414

415
416
417
418
419
420
421
422
    @staticmethod
    def get_global_rank(group: Any, rank: int) -> int:
        if group is dist.group.WORLD:
            return rank
        else:
            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
        return global_rank

423
424
425
426
427
428
    def _sync_param_groups(self, local_to_global: bool = False) -> None:
        """Sync learning rate and other optimizer attributes (needed to support schedulers).
        If the global param groups have been altered, and we want to make sure that the
        wrapped optimizer uses the up to date version.
        Conversely if the wrapped optimizer has new keys, we expose them through the global param groups"""

429
        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
430
431
432
433
434
            # Sync everything but the parameters
            for k in filter(lambda x: x != "params", local_group.keys()):
                if local_to_global:
                    global_group[k] = local_group[k]
                elif k in global_group.keys():
435
                    local_group[k] = global_group[k]
436

437
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """Collect the state shards of all the ranks, copied to CPU memory and ordered by rank.

        This takes part in one `broadcast_object` call per rank, matching the calls
        issued by `_broadcast_state_dict` on the other ranks.
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
        cpu = torch.device("cpu")
        collected: List[Dict[str, Any]] = []

        for rank in range(self.world_size):
            if rank != self.rank:
                # Fetch the optim state from the other replicas
                src_rank = self.get_global_rank(self.group, rank)
                replica_state = broadcast_object(
                    empty_buffer, src_rank=src_rank, group=self.group, dist_device=self._device
                )
                collected.append(recursive_copy_to_device(replica_state, non_blocking=True, device=cpu))
                logging.debug("State from rank %s received", rank)
                continue

            # This rank's own shard is read directly
            logging.debug("Saving self state")
            collected.append(recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=cpu))

            # Sync with other replicas: they all expect a broadcast coming from this rank
            broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)

        return collected

    def _broadcast_state_dict(self) -> None:
        """Broadcast this rank's state shard, and discard what the other ranks broadcast.

        One `broadcast_object` call is issued per rank, in rank order.
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

        for rank in range(self.world_size):
            if rank != self.rank:
                # Discard this tensor/rank, broadcast necessary for syncing
                src_rank = self.get_global_rank(self.group, rank)
                broadcast_object(empty_buffer, src_rank=src_rank, group=self.group, dist_device=self._device)
                continue

            # Send the state to the reference replica
            logging.debug(
                "Sending the sharded optimizer state to the reference replica from rank %s", rank,
            )
            broadcast_object(
                self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
            )
483
484
485
486

    def _free_other_grads(self) -> None:
        """Free all the gradients only useful for the other ranks
        """
487
488
        for rank, partition in enumerate(self.partition_parameters()):
            if rank == self.rank:
489
490
491
492
493
                continue

            for p in partition:
                for t in p["params"]:
                    t.grad = None
494

495
    def _broadcast_params(self) -> None:
        """Broadcast all the parameters, each of them from the rank which owns it.

        For every device and every source rank, the params tagged in `should_bucket_param` go through
        the bucket of this (device, rank) pair, and the bucket buffer is broadcast once the bucket
        reports being full. The other params are broadcast directly. All the calls are asynchronous,
        the resulting work handles are consumed before returning.

        Raises:
            AssertionError: if a param tagged for bucketing could not be added to its bucket
        """

        # The unroll callback is called when the broadcast is done.
        # If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
        # onto the corresponding parameters.
        def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
            def unroll() -> None:
                if src_rank != self.rank:
                    for flat in bucket.params:
                        flat.param.data.copy_(
                            bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
                        )

                bucket.reset()

            return unroll

        with torch.no_grad():
            # all the params on this device (inc all ranks)
            for device, device_params in self.per_device_params.items():
                buckets = self.buckets[device]

                # Bucket and issue all the async calls
                for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
                    global_src_rank = self.get_global_rank(self.group, src_rank)

                    for param in params:
                        # Bucket broadcast
                        if self.should_bucket_param[param]:
                            # The append is done outside of the assert statement: asserts are stripped
                            # when python runs with -O, and the param would then never be added to the bucket
                            appended = bucket.append(param)
                            assert appended, "Bucket overflow: max %s - current %s - adding %s" % (
                                bucket.max_size,
                                bucket.current_offset,
                                param.numel(),
                            )

                            if bucket.full():
                                self.work_handles.append(
                                    Workhandle(
                                        handle=dist.broadcast(
                                            tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
                                        ),
                                        callback=get_unroll_callback(src_rank, bucket),
                                    )
                                )

                        # Direct
                        else:
                            self.work_handles.append(
                                Workhandle(
                                    handle=dist.broadcast(
                                        tensor=param.data, src=global_src_rank, group=self.group, async_op=True
                                    ),
                                    callback=None,
                                )
                            )

        self._consume_work_handles()

    def _consume_work_handles(self) -> None:
        """ Consume all the futures which are tied to this optimizer's buckets.
        We start from the first/older ones, since they are the most likely to be ready and non-blocking
        """

        for work_handle in self.work_handles:
            work_handle.handle.wait()
            if work_handle.callback is not None:
                work_handle.callback()

        self.work_handles.clear()

    def _setup_bucket_strategy(self) -> None:
        """  Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
        (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
        over the wire.

        Generating the partition once and for all allows us to save some time at runtime, and to know when all the
        network requests have been issued.
        """
577

578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
        for device, per_rank_params in self.per_device_params.items():
            for dst_rank, params in enumerate(per_rank_params):
                offset = 0
                bucket_size = self.buckets[device][dst_rank].max_size

                for param in params:
                    if (offset + param.numel()) < bucket_size:
                        # This parameter is small enough to fit in the remaining size of the bucket
                        self.should_bucket_param[param] = True
                        offset += param.numel()
                    else:
                        # The parameters are sorted by size, so all the following parameters
                        # will be too big and can be skipped
                        self.should_bucket_param[param] = False

                # Register the max offset for this buffer
                self.buckets[device][dst_rank].max_offset = offset

        # Determine the max work handles in flight:
        # - all the direct reduce/broadcast + 1 bucket
        self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1