# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
from enum import Enum, auto
import functools
import logging
from math import inf
import time
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union

import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

from . import fsdp_optim_utils as ou

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401


class TrainingState(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
    to make sure APIs are called in the correct state.

    .. note::

        BACKWARD_PRE and BACKWARD_POST states are used to ensure we
        receive backward hooks in the correct order. It is used to catch
        unexpected order of hooks being called (likely due to our
        hook registration logic or autograd engine logic changes).

    TODO (Min): It would be nice to capture the stepping state as well.
        Maybe we can use the model.zero_grad() call, but not sure if it
        is called if optim.zero_grad() is used instead.
        It would be nice to have clear state transition be explicit like:

        zero_grad -> fwd -> bwd -> optionally accum grad by repeating
        fwd/bwd -> stepping -> loop back to zero_grad
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
    is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.

    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
    .. _DeepSpeed: https://www.deepspeed.ai/

    Usage::

        import torch
        from fairscale.nn.data_parallel import FullyShardedDataParallel
        torch.cuda.set_device(device_id)
        sharded_module = FullyShardedDataParallel(my_module)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        x = sharded_module(x, y=3, z=torch.Tensor([1]))
        loss = x.sum()
        loss.backward()
        optim.step()

    It is also possible to shard individual layers separately and have an outer
    wrapper handle any leftover parameters. This can be helpful to further
    reduce GPU memory usage, reduce system memory usage when initializing large
    models, and to improve training speed by overlapping the all-gather step
    across the forward pass. For example::

        import torch
        from fairscale.nn.wrap import enable_wrap, auto_wrap, wrap
        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

        fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
        with enable_wrap(**fsdp_params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            assert isinstance(self.l1, FSDP)
            # Separately wraps child modules with more than 1e8 params
            large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12, num_decoder_layers=12)
            self.l2 = auto_wrap(large_tfmr, min_num_params=1e8)
            assert isinstance(self.l2, FSDP)

    .. warning::

        The optimizer must be initialized *after* the module has been wrapped,
        since FSDP will shard parameters in-place and this will break any
        previously initialized optimizers.

    .. warning::

        If you wrap every parameter inside a nested FSDP and leave the outer
        FSDP without any parameters, activation checkpointing may trigger
        an assert during the backward pass. The solution is to leave some
        parameters to the outer FSDP.

    .. warning::

        If activation checkpointing is used with FSDP, it is strongly encouraged
        to use the ``checkpoint_wrapper`` function from FairScale instead of the
        ``checkpoint`` function from PyTorch.

    Args:
        module (nn.Module):
            module to be wrapped with FullyShardedDataParallel.
        process_group (Optional):
            process group for sharding
        reshard_after_forward (bool, Optional):
            if ``True``, reshard parameters after the forward pass. This saves
            memory but slows training. This is only relevant when resharding
            individual layers.
        mixed_precision (bool, Optional):
            if ``True``, inputs, activations and gradients will be kept in FP16;
            computation and communication will occur in FP16; and a (sharded)
            master copy of the model weights will be maintained in FP32.
        fp32_reduce_scatter (bool, Optional):
            if ``True``, then reduce-scatter gradients in FP32. This is only
            relevant when *``mixed_precision``* is ``True``.
        flatten_parameters (bool, Optional):
            if ``True``, flatten parameters into a single contiguous tensor,
            which improves training speed.
        cpu_offload (bool, Optional):
            if ``True``, offload FP32 params to CPU. This is only relevant when
            *``mixed_precision``* is ``True``.
        compute_dtype (torch.dtype, Optional):
            dtype for full parameters for computation. This defaults to
            ``torch.float32`` unless *``mixed_precision``* is set, in which case
            it defaults to ``torch.float16``.
        buffer_dtype (torch.dtype, Optional):
            dtype for buffers for computation. This defaults to ``compute_dtype``.
        move_grads_to_cpu (bool, Optional):
            move gradient shard to CPU after reduction. This is useful when
            combined with CPU-based optimizers. It defaults to the value of
            *``cpu_offload``*.
        bucket_cap_mb (int, Optional):
            FSDP will bucket parameters so that gradient reduction can
            be more efficient for small parameters.
            ``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets
            are sub-divided based on world_size, so the max shard size is roughly
            ``bucket_cap_mb / world_size``. There is one bucketer (with potentially
            multiple ``bucket_cap_mb`` sized buffers) shared by all FSDP instances.
            Large gradient tensors are directly reduced without using the buffers.
            The buffers are there to reduce communication overhead for small tensors.
            Overlapping with computation happens due to use of a different CUDA stream
            than the computation CUDA stream. The total memory overhead per buffer is around
            ``bucket_cap_mb / world_size * (world_size + 1)``.
            The buffers are allocated during the backward pass and freed at the end
            of the backward pass to save more memory for other phases of the
            training process.
            Note, the memory vs. speed tradeoff of bucket size is very different
            from that of the DDP engine. In DDP, the buffer size grows as ``1MB + n*cap_mb``
            until n is big enough to cover the entire model size. The order
            in which buffers become ready there is more rigid, and DDP requires all
            gradients to be computed in the backward pass. In FSDP, the buffer size
            does not change with model size (it changes based on the number of
            <dtype, device, process_group> tuples), and the gradient-ready order matters
            little since FSDP has a final flush call that ensures everything is reduced;
            not all gradients need to be known upfront. Overlapping with compute is
            done differently too.
            Values <= 0 disable bucketing.
            Default: 25.
        compute_device (torch.device, Optional):
            device for computation. If not given and module params are on a CUDA
            device, the param's device will be used. If not given and module
            params are on CPU, then the current CUDA device (as indicated by
            ``torch.cuda.current_device()``) will be used.
        no_broadcast_optim_state (bool, Optional):
            do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
            If you set this to ``True``, you are expected to overwrite the relevant state entries of the
            returned optimizer state dict with the proper state at each rank. This is useful for situations,
            like Mixture Of Experts, where all but a few parameters can fit on one node.
            Default: False
        state_dict_device (torch.device, Optional):
            device for parameters returned by :func:`state_dict`. If not given,
            this will default to ``compute_device``. Note that only the device
            type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
        clear_autocast_cache (bool):
            When using mixed precision training with `torch.amp.autocast`, if the model weights
            are in FP32, autocast maintains a cache for downcasted weights. The cache can cause
            GPU OOM during the forward pass. Setting this flag to ``True`` helps clear this
            cache as inner FSDP instances finish their part of the forward pass, to save GPU memory.
            Default: False
        verbose (bool):
            Set this to ``True`` to turn on verbose output for the model's string representation.
            Default: False
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        reshard_after_forward: bool = True,
        mixed_precision: bool = False,
        fp32_reduce_scatter: bool = False,
        flatten_parameters: bool = True,
        cpu_offload: bool = False,
        compute_dtype: Optional[torch.dtype] = None,
        buffer_dtype: Optional[torch.dtype] = None,
        move_grads_to_cpu: Optional[bool] = None,
        bucket_cap_mb: int = 25,
        compute_device: Optional[torch.device] = None,
        no_broadcast_optim_state: Optional[bool] = False,
        state_dict_device: Optional[torch.device] = None,
        clear_autocast_cache: bool = False,
        verbose: bool = False,
    ):
        init_start = time.time()
        super().__init__()
        self.process_group = process_group or dist.new_group()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.flatten_parameters = flatten_parameters
        self.cpu_offload = cpu_offload
        self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb
        self.compute_device = compute_device or _get_default_cuda_device(module)
        self.uncollected_opt_state: Dict[int, Dict] = {}
        self.no_broadcast_optim_state = no_broadcast_optim_state
        self.state_dict_device = state_dict_device or self.compute_device
        self.clear_autocast_cache = clear_autocast_cache
        self.verbose = verbose

        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.numel_padded_per_param: List[int] = []
        self._tstart = time.time()

        if self.fp32_reduce_scatter and not self.mixed_precision:
            raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
        if self.cpu_offload and not self.mixed_precision:
            raise ValueError("cpu_offload requires mixed_precision=True")

        # skip validation if the process group was created above
        if process_group:
            validate_process_group(self.compute_device, self.process_group)

        # enable pytorch sync_bn just in case model contains sync_bn layers.
        enable_pytorch_sync_bn(module)

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))

        self._has_params = len(params) > 0
        if not self._has_params:
            self.flatten_parameters = False

        if self.flatten_parameters:
            self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
            del module  # free original module in case it helps garbage collection
            self.params = [self._fsdp_wrapped_module.flat_param]
        else:
            self._fsdp_wrapped_module = module
            self.params = params

        # Shard module parameters in place
        self._shard_parameters_()

        # Make sure all parameters are sharded.
        for n, p in self.named_parameters():
            assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"

        self._reset_lazy_init()

        # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
        self._require_backward_grad_sync: bool = True

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState.IDLE

        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
        # prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        self._return_full_state_dict = True

        init_end = time.time()

        logging.debug(
            f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
        )

        # Flag to guard against multiple pre-backward hooks being executed per iteration.
        # This is reset at the end of the backward pass.
        self._pre_backward_hook_has_run = False

    def _get_gradient_predivide_factor(self, world_size: int) -> float:
        # Heuristic: split the overall division by world_size into a pre-reduction
        # factor (returned here, e.g. 4.0 for world_size=8) and a post-reduction
        # factor (world_size / this value), computed in __init__.
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)

    def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
        """Allowing user to override the pre and post divide factors.

        Args:
            pre (float): divide factor before the reduction.
            post (float): divide factor after the reduction.
            recursive (bool): recursively set it for all child FSDP instances or not.
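
        Example (a minimal sketch; ``fsdp_model`` is assumed to be the root
        FSDP instance wrapping the whole model)::

            # Skip the pre-reduction divide and average only once, after the
            # reduction, by the full data parallel world size.
            fsdp_model.set_gradient_divide_factors(1.0, float(fsdp_model.world_size), recursive=True)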
        """
        self.assert_state(TrainingState.IDLE)
        if recursive:
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel) and module != self:
                    module.set_gradient_divide_factors(pre, post, False)
        self.gradient_predivide_factor = pre
        self.gradient_postdivide_factor = post

    @property
    def module(self) -> nn.Module:
        return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance

    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
        """
        Applies ``fn`` recursively to every submodule (as returned by
        ``.children()``) as well as self. Typical use includes initializing the
        parameters of a model.

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (Callable[[nn.Module], None]): function to be applied to each submodule

        Returns:
            Module: self
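
        Example (a minimal sketch; ``init_weights`` is an illustrative helper,
        not part of this API)::

            def init_weights(m):
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)

            sharded_module = FullyShardedDataParallel(my_module)
            sharded_module.apply(init_weights)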
        """
        is_uninitialized = self._is_root is None
        self.assert_state(TrainingState.IDLE)
        with self.summon_full_params(recurse=False):
            return_value = super().apply(fn)
        # summon_full_params will call _lazy_init, which sets _is_root. However,
        # apply() may be called directly on children instances to do weight
        # init, so we should reset the _is_root flag in this case.
        if is_uninitialized and self._is_root:
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel):
                    module._reset_lazy_init()
        return return_value

    def _cast_buffers(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
    ) -> None:
        """Move all buffers to the given *device* and *dtype*.

        If *device* or *dtype* are not given, then they will default to
        ``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
        case of nested FSDP instances, we will respect the child instance's
        ``compute_device`` and ``buffer_dtype`` configuration.

        Args:
            device (torch.device, Optional):
                device to cast buffers to (defaults to compute_device)
            dtype (torch.dtype, Optional):
                dtype to cast buffers to (defaults to buffer_dtype)
            memo (Set, Optional):
                set of modules that have already been processed
        """
        if memo is None:
            memo = set()
        for module in self.modules():
            if module is not self and isinstance(module, FullyShardedDataParallel):
                # Allow any child FSDP instances to handle their own buffers.
                module._cast_buffers(device=device, dtype=dtype, memo=memo)
            elif module not in memo:
                memo.add(module)
                for name, buf in module.named_buffers(recurse=False):
                    if buf is None:
                        continue
                    buf = buf.to(device=device or self.compute_device)
                    if torch.is_floating_point(buf):
                        buf = buf.to(dtype=dtype or self.buffer_dtype)
                    setattr(module, name, buf)

    @property
    def params_with_grad(self) -> List[Parameter]:
        """[p for p in self.parameters() if p.grad is not None] """
        return [p for p in self.parameters() if p.grad is not None]

    @torch.no_grad()
    def clip_grad_norm_(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        # filter_params_fn: Callable[[Any], Any] = None,
    ) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all
        gradients together, as if they were concatenated into a single vector.
        Gradients are modified in-place.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
                for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
            handles the partitioning and multiple devices per rank under the
            hood. The default torch util is not applicable here, because each
            rank only has a partial view of all the grads in the model, so
            calling it in the OSS context would lead to different scaling being
            applied per subset of model parameters.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
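
        Example (a minimal sketch; ``sharded_module``, ``optim`` and ``batch``
        are illustrative, and every rank runs the same loop)::

            loss = sharded_module(batch).sum()
            loss.backward()
            sharded_module.clip_grad_norm_(max_norm=1.0)
            optim.step()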
        """
        # We don't call torch.cuda.synchronize() here, since clipping can be
        # inside the train loop and we probably don't want to force a GPU-CPU sync.
        # _lazy_init should be sufficient, since it will force the other streams
        # to sync with the default stream (via _wait_for_previous_optim_step).
        self._lazy_init()
        assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
        self.assert_state(TrainingState.IDLE)

        max_norm = float(max_norm)
        norm_type = float(norm_type)
        params_with_grad = self.params_with_grad
        if not self.children_share_process_group:
            raise NotImplementedError(
                "clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
            )
        # Computes the max norm for this shard's gradients and sync's across workers
        local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
        if norm_type == inf:
            total_norm = local_norm
            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
        else:
            total_norm = local_norm ** norm_type
            dist.all_reduce(total_norm, group=self.process_group)
            total_norm = total_norm ** (1.0 / norm_type)

        if self.move_grads_to_cpu:
            total_norm = total_norm.cpu()
        # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 (https://tinyurl.com/3wtxhhqq)
        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
        if clip_coef < 1:

            # multiply by clip_coef
            for p in params_with_grad:
                p.grad.detach().mul_(clip_coef.to(p.grad.device))  # type: ignore

        return total_norm

    @torch.no_grad()
    def _shard_parameters_(self) -> None:
        """
        At initialization we wrap a module with full parameters and shard the
        parameters in-place. Sharding is implemented by viewing each parameter
        as a 1D Tensor and retaining only a single slice, where the slice size
        is determined by the number of data parallel workers.

        Wrapping modules with many small parameters (or with a very large data
        parallel world size) will result in many small parameter shards and slow
        performance. In this case it's better to set *``flatten_parameters``* to
        ``True``, so that all of the small parameters in the module are combined
        into a single contiguous Tensor and sharded once.

        After this initial sharding is complete, the user can initialize a
        ``torch.optim.Optimizer`` in the usual way, i.e.::

        .. code-block:: python

            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

        The optimizer will see only a single slice of parameters and will thus
        allocate less memory for optimizer state, avoiding redundancy across
        data parallel workers.
        """
        self.numel_padded_per_param = []
        for p in self.params:
            assert not hasattr(p, "_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p._is_sharded = self.world_size > 1
            p._orig_size = p.data.size()

            if not p._is_sharded:
                self.numel_padded_per_param.append(0)
                continue
            p._is_sharded = True

            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data, num_padded = self._get_shard(p.data)
            self.numel_padded_per_param.append(num_padded)
            free_storage_(orig_data)
        assert len(self.numel_padded_per_param) == len(self.params)

    def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the local shard of a full tensor."""
        # Shard using torch.chunk to match all-gather/reduce-scatter.
        chunks = list(torch.flatten(tensor).chunk(self.world_size))
        while len(chunks) < self.world_size:
            chunks.append(chunks[0].new_empty(0))

        # Determine number of padding elements.
        num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
        assert num_to_pad >= 0, num_to_pad

        shard = chunks[self.rank].clone()
        if num_to_pad > 0:
            shard = F.pad(shard, [0, num_to_pad])
        return shard, num_to_pad

    def extra_repr(self) -> str:
        repr = (
            f"world_size={self.world_size}, "
            f"flatten_parameters={self.flatten_parameters}, "
            f"mixed_precision={self.mixed_precision}, "
        )
        if self.verbose:
            repr = (
                f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
                f"compute_dtype={self.compute_dtype}, "
                f"buffer_dtype={self.buffer_dtype}, "
                f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
                f"compute_device={self.compute_device}"
                f"cpu_offload={self.cpu_offload}, "
                f"move_grads_to_cpu={self.move_grads_to_cpu}, "
                f"bucket_cap_mb={self.bucket_cap_mb}, "
                f"clear_autocast_cache={self.clear_autocast_cache}"
            )
        return repr

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def __getstate__(self) -> Dict[str, str]:
        """Serialize the state of the current FullyShardedDataParallel instance.

        Some properties are not serializable (e.g., process groups, streams), so
        we remove them and try to reconstruct them in :func:`__setstate__`.
        """
        state = copy.copy(self.__dict__)
        state["is_sharded"] = [p._is_sharded for p in self.params]
        state["orig_sizes"] = [p._orig_size for p in self.params]
        if state["process_group"] is not None:
            state["process_group"] = "MISSING"  # process_group isn't pickleable
        self._reset_lazy_init()
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Intercept state setting and perform needed changes on params."""
        super().__setstate__(state)

        def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
            assert isinstance(p, Parameter)
            p.data = p.data.clone()  # move tensors out of shared memory
            p._is_sharded = is_sharded
            p._orig_size = size
            return p

        self.params = [
            fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
        ]
        del self.is_sharded
        del self.orig_sizes
        self._reset_lazy_init()

    # TODO (Min): figuring out how to do typing for this overloaded function.
    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]":  # type: ignore
        """
        Returns the whole (unsharded) state of the module. Parameters are not
        sharded, so the resulting state_dict can be loaded directly by the
        wrapped Module without any sharding-specific logic. Returned tensors
        will be full precision (e.g., FP32).

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
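
        Example (a minimal sketch; saving a consolidated checkpoint from rank 0,
        with ``rank`` and the file path being illustrative)::

            full_sd = sharded_module.state_dict()  # must run on all ranks
            if rank == 0:
                torch.save(full_sd, "checkpoint.pt")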
        """
        torch.cuda.synchronize()
        self._lazy_init()
        if self.mixed_precision:
            # Buffers dtype stays consistent with parameters.
            self._cast_buffers(dtype=torch.float32)

        if self._return_full_state_dict:
            if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
                with self.summon_full_params(recurse=False, volatile=True):
                    state_dict = super().state_dict(*args, **kwargs)
            else:
                state_dict = super().state_dict(*args, **kwargs)
        else:
            if self.flatten_parameters:
                assert isinstance(self.module, FlattenParamsWrapper)
                state_dict = self.module.flat_state_dict(*args, **kwargs)
            else:
                state_dict = super().state_dict(*args, **kwargs)

        if self.cpu_offload:
            for k in state_dict.keys():
                state_dict[k] = state_dict[k].cpu()

        if self.mixed_precision:
            # In case we are in mixed precision, restore buffers back to buffer_dtype.
            self._cast_buffers()
        return state_dict

    # TODO (Min): figuring out how to do typing for this overloaded function.
    def local_state_dict(self, *args, **kwargs):  # type: ignore
        """
        Returns the local (sharded) state of the module. Parameters are sharded,
        so the resulting state_dict can only be loaded after the Module has been
        wrapped with FullyShardedDataParallel.
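
        Example (a minimal sketch; one sharded checkpoint file per rank, with
        ``rank`` and the file paths being illustrative)::

            local_sd = sharded_module.local_state_dict()
            torch.save(local_sd, f"checkpoint_rank{rank}.pt")
            # ... later, after wrapping the same model in FSDP again:
            sharded_module.load_local_state_dict(torch.load(f"checkpoint_rank{rank}.pt"))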
        """
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            return self.state_dict(*args, **kwargs)

    @contextlib.contextmanager
    def _no_return_full_state_dict(self) -> Generator:
        backup = self._return_full_state_dict
        self._return_full_state_dict = False
        try:
            yield
        finally:
            self._return_full_state_dict = backup

    def load_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """
        Load a whole (unsharded) state_dict.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
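
        Example (a minimal sketch; the checkpoint path is illustrative and the
        full state_dict is loaded on every rank)::

            sharded_module.load_state_dict(torch.load("checkpoint.pt"))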
        """
        if self._return_full_state_dict:
            with self.summon_full_params():
                return self.module.load_state_dict(state_dict, strict)
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            return self.module.load_state_dict(state_dict, strict)

    def load_local_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """Load a local (sharded) state_dict."""
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            output = self.load_state_dict(state_dict, strict)
        return output

    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """
        A context manager to disable gradient synchronizations across DDP
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context.

        .. note:: This may result in higher memory usage because we will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.
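
        Example (a minimal sketch of gradient accumulation; ``sharded_module``,
        ``optim`` and the two batches are illustrative)::

            with sharded_module.no_sync():
                sharded_module(batch1).sum().backward()  # grads accumulated, not reduced
            sharded_module(batch2).sum().backward()      # grads reduced in this backward
            optim.step()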
        """
        self._lazy_init()
        assert self._is_root, "no_sync on inner FSDP is not supported"
        self.assert_state(TrainingState.IDLE)
        # This instance may wrap other FullyShardedDataParallel instances and we
        # need to set all of them to accumulate gradients.
        old_flags = []
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                old_flags.append((m, m._require_backward_grad_sync))
                m._require_backward_grad_sync = False
        try:
            yield
        finally:
            for m, old_flag in old_flags:
                m._require_backward_grad_sync = old_flag

    @contextlib.contextmanager
    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
        """
        A context manager to expose full params for the current FSDP instance.
        Can be useful *after* forward/backward for a model to get the params for
        additional processing or checking. Parameters will be gathered in full
        precision (e.g., FP32).

        .. note:: This can be used on inner FSDPs.

        .. note:: This can *not* be used within a forward or backward pass. Nor
            can forward and backward be started from within this context.

        .. note:: The full parameters will be freed after the context manager
            exits; it is up to the caller to clone them if needed.

        .. note:: The full parameters can be modified, but only the portion
            corresponding to the local param shard will persist after the
            context manager exits (unless ``volatile=True``, in which case there
            are no guarantees about persistence).

        Args:
            recurse (bool, Optional): recursively summon all params for nested
                FSDP instances (default: True)
            volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
                enabling this can be slightly more efficient (default: False)
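
        Example (a minimal sketch; ``sharded_module`` is an FSDP instance and
        the inspection below is illustrative)::

            with sharded_module.summon_full_params():
                # Full (unsharded) params are materialized inside this block.
                # Clone anything you want to keep before the context exits.
                full_params = [p.detach().clone() for p in sharded_module.parameters()]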
        """
        if recurse:
            with contextlib.ExitStack() as stack:
                # Summon all params for any nested FSDP instances.
                for module in self.modules():
                    if isinstance(module, FullyShardedDataParallel):
                        stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
                # Yield to the caller, with full params in all nested instances.
                yield
            # Exiting from the ExitStack will re-shard params.
            return
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            self.assert_state(TrainingState.IDLE)
            # Set the state so that we assert when trying to go into
            # forward/backward.
            self.training_state = TrainingState.SUMMON_FULL_PARAMS
            full_tensors = self._rebuild_full_params(force_full_precision=True)
            assert full_tensors is not None
            with contextlib.ExitStack() as stack:
                if self.flatten_parameters and self.module.is_flattened:
                    # Update flattened views to point to fully-sized tensors. We
                    # use self.params[0] instead of full_tensors since the
                    # latter may contain padding.
                    assert len(self.params) == 1
                    assert isinstance(self.module, FlattenParamsWrapper)
                    stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
                try:
                    yield
                finally:
                    stack.close()
                    assert len(full_tensors) == len(self.params)
                    for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
                        if not volatile:
                            # Copy any changes made to the full params back into
                            # the corresponding local shards.
                            local_shard, _ = self._get_shard(full_tensor)
                            p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                        if safe_to_free:
                            free_storage_(full_tensor)
                    self.has_full_params = False
                    self._use_fp32_param_shard()
                    self.training_state = TrainingState.IDLE

    def _reset_lazy_init(self) -> None:
        """Reset instance so :func:`_lazy_init` will run on the next forward."""
        self._is_root: Optional[bool] = None
        self._queue_wait_for_post_backward_closure: Optional[Callable] = None
        self._streams: Dict[str, torch.cuda.Stream] = {}
        self._reducer: Optional[ReduceScatterBucketer] = None
        for p in self.params:
            if hasattr(p, "_fp32_shard"):
                del p._fp32_shard  # reset _init_param_attributes

    def _lazy_init(self) -> None:
        """Initialization steps that should happen lazily, typically right
           before the first forward pass.
        """
        # Initialize param attributes lazily, in case the param's dtype or
        # device changes after __init__.
        for p in self.params:
            self._init_param_attributes(p)

        # Initialize _is_root and setup streams. These steps would ideally
        # happen in __init__, but _is_root can only be determined after the
        # entire model hierarchy is setup, thus we run it lazily.
        if self._is_root is None:
            self._set_is_root()
            self._setup_streams()

        if self._is_root:
            # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
            # applies recursively, we only call this from the root instance.
            self._cast_buffers()

            # Don't free the full params for the outer-most (root) instance,
            # since those params will be needed immediately after for the
            # backward pass.
            self.reshard_after_forward = False

            # Due to the use of streams, we need to make sure the previous
            # ``optim.step()`` is done before we all-gather parameters.
            self._wait_for_previous_optim_step()

    @torch.no_grad()
    def _init_param_attributes(self, p: Parameter) -> None:
        """
        We manage several attributes on each Parameter instance. The first two
        are set by :func:`_shard_parameters_`:

            ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param).
            ``_orig_size``: the size of the original Parameter (before sharding)

        The remaining attributes are set here:
            ``_fp32_shard``: a single shard of the parameters in full precision
                (typically FP32, but this is dependent on the dtype of the model
                as it's passed in by the user). This can be on CPU or GPU
                depending on the value of *``cpu_offload``*.
            ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
                a single shard of the parameters in FP16, used for all-gather.
            ``_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. This will be resized in place and
                only materialized (via all-gather) as needed.
        """
        assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
        if hasattr(p, "_fp32_shard"):
            return

        # A single shard of the parameters in full precision.
        p._fp32_shard = p.data

        if self.mixed_precision:
            assert p._fp32_shard.dtype == torch.float32

            if self.cpu_offload:
                assert p._fp32_shard.device == torch.device("cpu")
                # If we plan to keep the FP32 parameters on CPU, then pinning
                # memory allows us to later use non-blocking transfers when moving
                # the FP32 param shard to compute_device.
                p._fp32_shard = p._fp32_shard.pin_memory()
                p.data = p._fp32_shard

            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed.
            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            free_storage_(p._fp16_shard)
        else:
            p._fp16_shard = None  # use _fp32_shard

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p._is_sharded:
            p._full_param_padded = torch.zeros(
                p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
            )
            free_storage_(p._full_param_padded)

        if self.move_grads_to_cpu:
            # We can optionally move the grad shard to CPU during the backward
            # pass. In this case, it's important to pre-allocate the CPU grad
            # shard in pinned memory so that we can do a non-blocking transfer.
            p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

    def _set_is_root(self) -> None:
        """If ``True``, implies that no other :class:`FullyShardedDataParallel`
        instance wraps this one. Called once by :func:`_lazy_init`.
        Also sets self.children_share_process_group = True if all child
        instances share the same process group. If some child instances use a
        different process group, self.clip_grad_norm_ will raise an error.
        """
        if self._is_root is not None:
            return
        # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
        self._is_root = True
        assert self._queue_wait_for_post_backward_closure is None
        self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
        # As the root, we now set all children instances to False and
        # give them a closure to try to queue a wait_for_post_backward.
        self.children_share_process_group = True
        for n, m in self.named_modules():
            # `n != ""` excludes self.
            if n != "" and isinstance(m, FullyShardedDataParallel):
                # We relax the assert for non-root instances, when the nested,
                # already-initialized module is wrapped again in FSDP later, for
                # example after training to run inference.
                assert m._is_root is None or not m._is_root
                if m._is_root is None:
                    m._is_root = False
                    # When the root instance doesn't have params, allow child instances
                    # to queue the post_backward hook.
                    #
                    # TODO (Min): we should think if we can have an empty param at the root
                    #             so that root always has a callback on the backward graph.
                    if not self._has_params:
                        assert m._queue_wait_for_post_backward_closure is None
                        m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                if m.process_group != self.process_group:
                    self.children_share_process_group = False

                # If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
                # Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
                m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
                    (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
                )

    def _setup_streams(self) -> None:
        """Create streams to overlap data transfer and computation."""
        if len(self._streams) > 0 or not self._is_root:
            return
        # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
        self._streams["fp32_to_fp16"] = torch.cuda.Stream()
        # Stream for all-gathering parameters.
        self._streams["all_gather"] = torch.cuda.Stream()
        # Stream for overlapping grad reduction with the backward pass.
        self._streams["post_backward"] = torch.cuda.Stream()
        # Helper for bucketing reduce-scatter ops. This is also shared with
        # children instances to improve bucket utilization.
        self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
        # We share streams with all children instances, which allows them to
        # overlap transfers across the forward pass without synchronizing with
        # the default stream.
        for n, m in self.named_modules():
            if n != "" and isinstance(m, FullyShardedDataParallel):
                m._streams = self._streams
                m._reducer = self._reducer

    def _wait_for_previous_optim_step(self) -> None:
        """
        The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
        instance) needs to synchronize with the default stream to ensure the
        previous optimizer step is done.
        """
        if self.mixed_precision:
            self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
        else:
            self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        self._lazy_init()

        # Start of a forward pass.
        self.training_state = TrainingState.FORWARD

        if self._is_root and self.mixed_precision:
            args, kwargs = cast_inputs_to_fp16(*args, **kwargs)

        # All-gather full parameters. This will also transfer FP32 parameters to
        # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
        self._rebuild_full_params()

        # Register backward hooks to reshard params and reduce-scatter grads.
        # These need to be re-registered every forward pass.
        self._register_post_backward_hooks()

        outputs = self.module(*args, **kwargs)

        if self.reshard_after_forward:
            self._free_full_params()
            if self.mixed_precision:
                self._free_fp16_param_shard()

        # Switch to main FP32 param shard. We maintain this invariant throughout
        # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
        # also ensures that after the first forward, the optimizer state will be
        # initialized with the correct dtype and (sharded) size, since optimizer
        # state is typically initialized lazily in ``optim.step()``.
        self._use_fp32_param_shard()

        # Register pre-backward hooks to all-gather the params for the backward
        # pass (if output's grad was needed). This won't register anything if
        # we are in eval mode.
        #
        # Some models run the forward pass multiple times; we need to register the
        # pre-backward hook on every output since the last output's hook has to
        # fire first to set up for backward. However, we use ``self._pre_backward_hook_has_run``
        # to prevent repeated overhead from multiple hook callbacks.
        outputs = self._register_pre_backward_hooks(outputs)

        # Done with a forward pass.
        self.training_state = TrainingState.IDLE

        # Only need to clear cache during forward. During backward, the cache is not used.
        # TODO (Min): Future PyTorch versions may provide a way to completely disable this
        #     cache. Update this when that's available.
        if self.clear_autocast_cache:
            torch.clear_autocast_cache()

        return outputs

    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
        """Register pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward.

        Returns:
            outputs: new outputs with hooks registered if they require gradient.
        """
        if not torch.is_grad_enabled():
            return outputs  # don't register hooks if grad isn't enabled

        def _pre_backward_hook(*unused: Any) -> None:
            if self._pre_backward_hook_has_run:
                return  # only run once (from multiple outputs or multiple forward passes)
            self._pre_backward_hook_has_run = True

            # Start of a backward pass.
            self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
            self.training_state = TrainingState.BACKWARD_PRE

            # All-gather full parameters.
            if self.reshard_after_forward:
                self._rebuild_full_params()
            else:
                self._use_full_params()

            # Make sure p.grad has the correct size/device (or set it to None).
            self._prep_grads_for_backward()

        def _register_hook(t: torch.Tensor) -> torch.Tensor:
            if t.requires_grad:
                t.register_hook(_pre_backward_hook)
            return t

        # Attach hooks to Tensor outputs.
        outputs = apply_to_tensors(_register_hook, outputs)

        return outputs

    def _register_post_backward_hooks(self) -> None:
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
        """
        Register backward hooks to reshard params and reduce-scatter grads.

        This is called during forward pass. The goal is to attach a hook
        on each of the parameter's gradient generating function (``grad_acc``
        below) so that the hook is called *after* all gradients for that
        param are computed.

        Goals:

        1. We want the hook to fire once and only once *after* all gradients
        are accumulated for a param.
        2. If it fires more than once, we end up incorrectly sharding the grad
        multiple times. (could lead to a dimension that is too small)
        3. If it fires once but too early or doesn't fire, we leave gradients
        unsharded. (could lead to a dimension that is too large)

        Due to multiple-pass forward, this function can be called on
        the same parameter multiple times in a single forward pass. If we register
        the hook multiple times, we end up getting called multiple times. We
        could try to get a new hook every time and delete the previously
        registered one. However, for an *unknown reason* (I have debugged it for
        a long time!), in mixed precision mode, we get two different ``grad_acc``
        objects below during different calls of this function (in the same
        forward pass). If we keep the last one, the hook ends up firing too
        early. In full precision mode, we luckily get the *same* ``grad_acc``
        object, so deleting and re-registering still ensures the hook fires
        once after all gradients are generated.

        Empirically, keeping the first hook registered per forward pass seems to
        work the best. We do need to remove the hook at the end of the
        backward pass. Otherwise, the next forward pass will not register
        a new hook, which is needed for a new forward pass.
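
        A minimal sketch of the underlying pattern (illustrative only, shown
        outside of FSDP for a single parameter ``p``)::

            p_tmp = p.expand_as(p)  # get a grad_fn on a view of p
            grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node
            handle = grad_acc.register_hook(lambda *unused: print("grads ready for p"))
            # ... run forward and backward ...
            handle.remove()  # remove it so the next forward pass can re-register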
        """
        if not torch.is_grad_enabled():
            return  # don't register grad hooks if grad isn't enabled
        if self._is_root:
            # This actually means that only root instance has this field
            # defined. Accidentally accessing this field will assert on all
            # other instances, giving us a nice bug checker.
            self._post_backward_callback_queued = False
        for p in self.params:
            if p.requires_grad:
                if hasattr(p, "_shard_bwd_hook"):
                    continue
                # Register a hook on the first call; empirically, autograd
                # fires it at the end for this param, which makes sense.
                p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
                grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its AccumulateGrad object.
                handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
                p._shard_bwd_hook = (grad_acc, handle)

    @torch.no_grad()
    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        """
        At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will replace
        ``param.grad`` with a single shard of the summed gradient across all
        GPUs. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by :func:`_shard_parameters_`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state. When the checkpoint
        # fwd counter is used, IDLE is also possible since the pre-backward hook
        # is not triggered (see ``auto_wrap_bn`` below; we have to use
        # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).
        if hasattr(self, "_checkpoint_fwd_counter"):
            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
        else:
            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
        self.training_state = TrainingState.BACKWARD_POST
        if param.grad is None:
            return

        if param.grad.requires_grad:
            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")

        # If this is a checkpointed module, we check if the following
        # counter reaches 0. If not, it is not the final backward call
        # for this module yet. Therefore, we early return in that case.
        if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
            if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
                return

        if self._require_backward_grad_sync or self.reshard_after_forward:
            # Free full params. As a special case, we don't free the full params
            # when in a ``no_sync`` context (as inversely indicated by
            # ``self._require_backward_grad_sync``), since the params will not
            # get updated before the next forward.
            self._free_full_params([param])

        if self.mixed_precision:
            # This is a no-op if reshard_after_forward is True, since we already
            # free the param shard when rebuilding the full params in the
            # pre_backward_hook.
            self._free_fp16_param_shard([param])

        # Switch to FP32 shard after backward.
        self._use_fp32_param_shard([param])

        # (Try to) enqueue a callback at the end of the backward pass to ensure that all
        # post-backward work has finished. We only need one callback, and all instances
        # of FSDP (root and children) make this attempt here so that the callback is queued
        # no matter which instance(s) own(s) params.
        assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
        if self._queue_wait_for_post_backward_closure is not None:
            self._queue_wait_for_post_backward_closure()

        if not self._require_backward_grad_sync:
            return

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data

            if self.mixed_precision and self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.to(param.dtype)

            if self.gradient_predivide_factor > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
                param.grad.data.div_(self.gradient_predivide_factor)

            callback_fn = functools.partial(self._post_reduction_hook, param)
            if param._is_sharded:
                assert param._is_sharded
                assert self._reducer is not None
                grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
                self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                callback_fn(param.grad.data)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])

    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        assert param.grad is not None
        self.assert_state(TrainingState.BACKWARD_POST)
        param.grad.data = reduced_grad
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            param.grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = param.grad.data
            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())
        # Optionally move gradients to CPU, typically used if one is running
        # the optimizer on the CPU.
        if self.move_grads_to_cpu:
            param._cpu_grad.copy_(param.grad.data, non_blocking=False)
            # Don't let this memory get reused until after the transfer.
            param.grad.data.record_stream(torch.cuda.current_stream())
            param.grad.data = param._cpu_grad

    def _queue_wait_for_post_backward(self) -> None:
        """Try to queue a `wait_for_post_backward` callback.

        Only called on the root and only queues one callback. But it can be called
        by children FSDP instances via a closure, in case the root instance doesn't
        own any params.
        """
        assert self._is_root
        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
        if not self._post_backward_callback_queued:
            self._post_backward_callback_queued = True
            Variable._execution_engine.queue_callback(self._wait_for_post_backward)

    @torch.no_grad()
    def _wait_for_post_backward(self) -> None:
        """Wait for post-backward to finish. Only called on root instance."""
        assert self._is_root
        if self._has_params:
            self.assert_state(TrainingState.BACKWARD_POST)
        else:
            self.assert_state(TrainingState.BACKWARD_PRE)

        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self._streams["post_backward"]):
                assert self._reducer is not None
                self._reducer.flush()
            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
            if self.move_grads_to_cpu:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()

        # A backward pass is done, clean up below.

        # Free reducer buffers.
        if self._reducer is not None:
            self._reducer.teardown()

        def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
            """Helper used below on all fsdp modules."""
            for p in fsdp_module.params:
                if p.requires_grad:
                    if hasattr(p, "_shard_bwd_hook"):
                        assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
                        p._shard_bwd_hook[1].remove()
                        delattr(p, "_shard_bwd_hook")

        # Update root and nested FSDP's hooks and flags.
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                _remove_shard_bwd_hook(m)
                m._pre_backward_hook_has_run = False
                if m._has_params:
                    if any(p.requires_grad for p in m.params):
                        m.assert_state(TrainingState.BACKWARD_POST)
                    else:
                        # Unlikely case; should only happen if `m` has params but none of the
                        # params have `requires_grad==True`.
                        m.assert_state(TrainingState.IDLE)
                else:
                    m.assert_state(TrainingState.BACKWARD_PRE)
                m.training_state = TrainingState.IDLE

    @torch.no_grad()
    def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``force_full_precision=False`` and the full params are already gathered.
        """
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p._is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p._is_sharded:
                if self.mixed_precision and not force_full_precision:
                    p.data = p._fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p._full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if self.mixed_precision and not force_full_precision:
                self._cast_fp32_param_shards_to_fp16()

            for p in self.params:
                if not p._is_sharded:  # e.g., when world_size == 1
                    update_p_data()
                else:
                    # If self.cpu_offload and force_full_precision, we need to cast
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p._full_param_padded.device)

                    p_size = p._full_param_padded.size()
                    assert p_size.numel() % self.world_size == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate a fresh tensor in full precision since we are in
                        # mixed precision and a full precision rebuild is requested.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p._full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage_(p._full_param_padded, size=p_size)
                        output_tensor = p._full_param_padded

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    chunks = list(output_tensor.chunk(self.world_size))
                    dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if self.mixed_precision and not force_full_precision:
                        self._free_fp16_param_shard([p])
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def _use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.
        """
        assert self.has_full_params
        for p in self.params:
            if not p._is_sharded:
                if self.mixed_precision:
                    assert p._fp16_shard.storage().size() != 0
                    p.data = p._fp16_shard
            else:
                assert p._full_param_padded.storage().size() != 0
                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)

    @torch.no_grad()
    def _prep_grads_for_backward(self) -> None:
        """Make sure p.grad has the correct size/device, otherwise set it to None."""
        for p in self.params:
            if p.grad is not None and (p.grad.size() != p._orig_size or p.grad.device != p.data.device):
                p.grad = None

    @torch.no_grad()
    def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["all_gather"]):
            for p in params:
                if not p._is_sharded:  # e.g., world_size == 1
                    if self.mixed_precision:
                        self._free_fp16_param_shard([p])
                    continue
                # There may be external references to the Tensor Storage that we
                # can't modify, such as references that are created by
                # ctx.save_for_backward in the forward pass. Thus when we
                # unshard parameters, we should reuse the original Tensor
                # Storage object and unshard it in-place. For now, just resize
                # the Storage to 0 to save memory.
                free_storage_(p._full_param_padded)

    @torch.no_grad()
    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p._fp32_shard

    @torch.no_grad()
    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p._fp16_shard is not None
                alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
                p._fp16_shard.copy_(
                    # If cpu_offload is True, this will be non-blocking because
                    # _fp32_shard is pinned, otherwise it's a no-op.
                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
                )
                p.data = p._fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p._fp16_shard is not None:
                # _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
                # free it until the work in the current stream completes.
                p._fp16_shard.record_stream(current_stream)
                free_storage_(p._fp16_shard)

    def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
        """Assert we are in the given state."""
        # Since assert can be turned off and this error checking
        # is really important, we use explicit error checking
        # and raise a ValueError if needed.
        if isinstance(state, TrainingState):
            state = [state]
        if self.training_state not in state:
            msg = f"expected to be in states {state} but current state " f"is {self.training_state}"
            # In case we are failing in the context of an autograd hook, asserting
            # may not generate a useful msg. So, let's print it to be sure.
            if self.rank == 0:
                print(f"Asserting FSDP instance is: {self}")
                print(f"ERROR: {msg}")
                traceback.print_stack()
            raise ValueError(msg)

    def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
        """Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
        dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
        world_pad_info: List[List[List[int]]] = []  # this will contain values from the whole world.
        for rank in range(self.world_size):
            if rank == self.rank:
                pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
            else:
                pad_info = dummy_tensor  # type: ignore
            pad_info = broadcast_object(
                pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
            )
            if self.rank == 0:
                world_pad_info.append(pad_info)  # type: ignore
        return world_pad_info

    def _gather_optim_state(
        self, sd_state: Dict[int, Dict[str, Any]]
    ) -> Tuple[Dict[int, Dict[str, List]], Dict[int, Dict[str, List]]]:
        """For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
        gathered_state: Dict[int, Dict[str, List[Any]]] = {}
        singleton_state: Dict[int, Dict[str, List[Any]]] = {}  # Dimensionless tensor
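        # (e.g., scalar counters such as an optimizer's "step" value, when stored as 0-dim tensors)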
        for k, v in sd_state.items():
            gathered_state[k] = {}
            singleton_state[k] = {}
            desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size()  # type: ignore
            buffer = None  # for sharded tensors
            singleton_buffer = None  # for singleton tensors
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    t = t.to(self.compute_device)

                if ou.is_singleton_tensor(t):
                    if singleton_buffer is None:
                        singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
                    dist.all_gather(singleton_buffer, t, group=self.process_group)
                    if self.rank == 0:
                        singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
                        assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
                elif torch.is_tensor(t):
                    if buffer is None:
                        buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
                    dist.all_gather(buffer, t, group=self.process_group)
                    if self.rank == 0:
                        gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
                elif self.rank == 0:  # Add non tensor state
                    gathered_state[k][buffer_name] = [t]

        return gathered_state, singleton_state

    def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
        """Return the last known global optimizer state. The returned state is compatible with PyTorch, in that the
        sharded properties are not exposed. Multiple parameter groups are not yet supported.

        This should be called only on the root FSDP instance.
        Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.

        Args:
            optim (Optimizer): an optimizer instance for this FSDP rank. Its state_dict is
                        used in the consolidation. However, its state is not modified.

        Returns:

            * A dict with four entries (On rank zero, other workers return ``None``)
                * state - a dict holding gathered optimization state, 1 entry per unflat parameter
                * param_groups - a dict containing the 1 parameter group
                * param_id_map - global (unflat) to local (flat) id mapping
                * uncollected_local_ids - keys in the state dict that were not broadcast
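
        Example (an illustrative sketch; assumes ``model`` is the root FSDP
        instance and ``optim`` is the optimizer built from ``model.parameters()``)::

            full_osd = model.gather_full_optim_state_dict(optim)
            if torch.distributed.get_rank() == 0:
                torch.save(full_osd, "full_optim_state.pt")  # hypothetical path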

        """
        if not self.flatten_parameters:
            raise NotImplementedError("optim state dict requires flatten_parameters=True")

        self._lazy_init()
        sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
        assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
        assert len(sd["param_groups"]) == 1, "Param groups are not supported"
        # We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
        state, singleton_state = self._gather_optim_state(sd.pop("state"))
        pad_info = self._broadcast_pad_info_to_r0()
        if self.rank != 0:
            return None
        # Unify the shard states by concatenating tensors and unflattening params
        new_state_dict = ou.build_unflat_state_dict(
            self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
        )
        self.uncollected_opt_state = {}
        assert "uncollected_local_ids" in new_state_dict
        return new_state_dict

    @property
    def _fsdp_instances(self) -> List[nn.Module]:
        """Returns all fsdp modules in self.modules() including self."""
        return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

    def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
        new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
        if self.rank == 0:
            # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
            self.uncollected_opt_state = {
                k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
                for k, v in osd["state"].items()
                if k in uncollected_ids
            }

        pg = copy.deepcopy(osd["param_groups"])
        new_dct["param_groups"] = pg
        return new_dct

    def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Get the portion of the optimizer state dict associated with the shard

        This can be used to get the right sharded optimizer state to be loaded
        into the sharded optimizer for this FSDP rank.

        Args:
            full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state_dict``, or loaded from a checkpoint.

        Returns:
            (dict): a shard of the optimizer state.
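
        Example (an illustrative sketch; assumes ``model`` is the root FSDP
        instance, ``optim`` is its local optimizer, and ``full_optim_state.pt``
        was saved from ``gather_full_optim_state_dict`` on the same model)::

            full_osd = torch.load("full_optim_state.pt")  # hypothetical path
            shard = model.get_shard_from_optim_state_dict(full_osd)
            optim.load_state_dict(shard)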
        """
        # Assert nesting is the same as it was at save time
        instance_list = self._fsdp_instances
        ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
        ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
        if self.flatten_parameters:
            full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
            assert len(full_optim_state_dict["state"]) in (
                0,
                len(instance_list),
            ), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

        # get the portion of dict associated with the shard, in place
        for id, s in full_optim_state_dict["state"].items():
            for k, v in s.items():
                if torch.is_tensor(v) and id not in ids_not_to_shard:
                    v_shard, _ = self._get_shard(v)
                elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
                    # if we are resuming on larger world size, take first entry
                    v_shard = v[0] if self.rank >= len(v) else v[self.rank]
                    assert ou.is_singleton_tensor(v_shard)
                else:
                    v_shard = v  # don't shard entries that are not tensors
                full_optim_state_dict["state"][id][k] = v_shard

        return full_optim_state_dict

    def _print_r0(self, msg: str, restart: bool = False) -> None:
        """Debugging utility to print memory usage stats nicely on rank 0"""
        if restart:
            self._tstart = time.time()
        if self.rank == 0:
            gb_denom = 1024 ** 3
            logging.info(
                f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
            )


def _get_default_cuda_device(module: nn.Module) -> torch.device:
    """Try to infer CUDA device from module parameters."""
    try:
        compute_device = next(module.parameters()).device
        if compute_device.type == "cuda":
            return compute_device
    except StopIteration:
        pass
    # Fall back to current CUDA device
    return torch.device("cuda")


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.
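
    For example (illustrative)::

        args, kwargs = cast_inputs_to_fp16(torch.rand(2), scale=torch.rand(2))
        # args[0] and kwargs["scale"] are now torch.float16 tensors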
    """

    def fn(x: torch.Tensor) -> torch.Tensor:
        if x.dtype is torch.float32:
            return x.half()
        return x

    return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def free_storage_(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


@torch.no_grad()
def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
    """Allocate storage for a tensor."""
    if data.storage().size() == size.numel():  # no need to reallocate
        return
    assert data.storage().size() == 0
    data.storage().resize_(size.numel())


def _post_state_dict_hook(
    module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
    # Assuming we are in a ``summon_full_params()`` context, we need to clone
    # each tensor so that it does not get freed (in-place) when the context
    # exits. At the same time, this hook can be called multiple times
    # recursively, so we need to make sure that we only clone each tensor at
    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
    # which keeps track of tensors that are no longer at risk of being freed.
    for key in state_dict.keys():
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
            continue
        if state_dict[key].device.type != module.state_dict_device.type:
            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
            state_dict[key]._has_been_cloned = True
        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
            # We copy the state_dict since full param will be freed after we
            # exit the ``summon_full_params()`` context.
            state_dict[key] = state_dict[key].clone()
            state_dict[key]._has_been_cloned = True

    # Remove "_fsdp_wrapped_module." prefix
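    # (e.g., illustratively, "_fsdp_wrapped_module.layer.weight" -> "layer.weight")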
    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
    return state_dict


def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")


########################################################################################
# Below are APIs used together with FSDP, but not directly part of FSDP.
########################################################################################


def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ProcessGroup = None) -> nn.Module:
    """
    Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
    to sync BN is used and the outer FSDP is flattening.

    We put BN in its own full precision, unflattened, single GPU group FSDP. Note, SyncBNs still have
    a group size == world_size. The input and output for BN are still FP16 in mixed precision mode.
    See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html

    This needs to be done at each rank, like models being wrapped by FSDP at each rank.

    Args:
        module (nn.Module):
            The model (or part of the model) in which BN to be pre-wrapped.
        single_rank_pg (bool):
            If true, put BNs in a single-rank process group. Default False.
            This might be needed for Apex sync BN support. Still under construction.
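        process_group (ProcessGroup):
            Optional process group to use for the wrapped BN instances when
            ``single_rank_pg`` is False. Default None.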

    Returns:
        Processed module, where BNs are wrapped with a special FSDP instance.
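
    Example (a minimal sketch; ``MyConvNet`` is a hypothetical module containing
    BatchNorm layers)::

        model = auto_wrap_bn(MyConvNet())
        model = FullyShardedDataParallel(model, mixed_precision=True)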
    """

    def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
        is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        if recurse:
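            # Always recurse into children unless the module type is a forced-leaf module.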
            return not isinstance(
                module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES)  # type: ignore
            )
        else:
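            # Only wrap this module if it is a BN instance and not an excluded wrap type.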
            return is_bn and not isinstance(
                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)  # type: ignore
            )

    pg = None
    if single_rank_pg:
        # No sharding with this single member group.
        my_rank = dist.get_rank()
        pg = dist.new_group(ranks=[my_rank])
    else:
        pg = process_group

    fsdp_config = {
        "wrapper_cls": FullyShardedDataParallel,
        "process_group": pg,
        "mixed_precision": False,  # Keep the weights in FP32.
        "flatten_parameters": False,  # Do not flatten.
        # Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
        # **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
        # within the checkpoint's outer backward when multiple forward passes are used.
        "reshard_after_forward": False,
        # No bucketing or small bucketing should be enough for BNs.
        "bucket_cap_mb": 0,
    }

    with enable_wrap(wrap_bn_only_policy, **fsdp_config):
        return auto_wrap(module)