"vscode:/vscode.git/clone" did not exist on "aca61bc0b310b7d63dfb2654dc76aafc25543136"
fully_sharded_data_parallel.py 124 KB
Newer Older
1
2
3
4
5
6
7
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
from dataclasses import dataclass
from enum import Enum, auto
import functools
import logging
from math import inf
import os
import tempfile
import time
import traceback
import typing
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
    ProcessGroupName,
    chunk_and_pad,
    enable_pytorch_sync_bn,
    get_process_group_cached,
    validate_process_group,
)
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

from . import fsdp_optim_utils as ou

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401
# TODO: Remove the toggle here when github open issue #801 is resolved.
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True

try:
    import fairscale.experimental.nn.ssd_offload as ssd_offload
    from fairscale.experimental.nn.ssd_offload import SsdFlatParameter

    import_ssd_offload = True
except ImportError:
    # ssd_offload requires the latest PyTorch nightly version.
    import_ssd_offload = False
    pass


class TrainingState(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
    to make sure APIs are called in the correct state.

    .. note::

        The BACKWARD_PRE and BACKWARD_POST states are used to ensure we
        receive backward hooks in the correct order. They are used to catch
        an unexpected ordering of hook calls (likely due to changes in our
        hook registration logic or in the autograd engine).

    TODO (Min): It would be nice to capture the stepping state as well.
        Maybe we can use the model.zero_grad() call, but it is unclear whether that
        is called when optim.zero_grad() is used instead.
        It would also be nice to have the state transitions made explicit, like:

        zero_grad -> fwd -> bwd -> optionally accum grad by repeating
        fwd/bwd -> stepping -> loop back to zero_grad
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


# Data classes containing FSDP parameter constructs

# Offload config for specifying SSD options (initially at least)
@dataclass
class OffloadConfig:
    """Class for specifying all arguments related to offloading parameters."""

    # Offload type: currently only "ssd_offload" is supported.
    offload_type: Optional[str] = None
    # Path to the directory for storing parameters offloaded to disk.
    dir: Optional[str] = None
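    # Example (an illustrative sketch, not from the original source): enable the
    # experimental SSD offload with a custom directory that exists on every rank.
    #
    #   offload_conf = OffloadConfig(offload_type="ssd_offload", dir="/mnt/nvme/fsdp_params")
    #   model = FullyShardedDataParallel(my_module, offload_config=offload_conf)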


class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
    is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
    .. _DeepSpeed: https://www.deepspeed.ai/

    Pseudo-code usage::

        import torch
        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

        torch.cuda.set_device(device_id)
        sharded_module = FSDP(my_module)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        x = sharded_module(x, y=3, z=torch.Tensor([1]))
        loss = x.sum()
        loss.backward()
        optim.step()

    It is also possible to shard individual layers separately and have an outer
    wrapper handle any leftover parameters. This can be helpful to further
    reduce GPU memory usage, reduce system memory usage when initializing large
    models, and improve training speed by overlapping the all-gather step
    across the forward pass. For example::

        import torch
        from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap
        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        from fairscale.utils.testing import dist_init, teardown, rmf

        result = dist_init(0, 1, "/tmp/t1", "/tmp/t2")
        assert result
        fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
        with enable_wrap(**fsdp_params):
            l1 = wrap(torch.nn.Linear(5, 5))
            assert isinstance(l1, FSDP)
            # Wraps layer in FSDP by default if within context
            # Separately wraps child modules with more than 1e8 params
            large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12,
                                              num_decoder_layers=12)
            l2 = auto_wrap(large_tfmr)
            assert isinstance(l2.encoder, FSDP)
            assert isinstance(l2.decoder, FSDP)
            print(l2)  # You can print the model to examine FSDP wrapping.
        teardown()
        rmf("/tmp/t1")
        rmf("/tmp/t2")

    .. warning::

        The optimizer must be initialized *after* the module has been wrapped,
        since FSDP will shard parameters in-place and this will break any
        previously initialized optimizers.

    .. warning::

        If you wrap every parameter inside a nested FSDP and leave the outer
        FSDP without any parameters of its own, activation checkpointing may trigger
        an assert during the backward pass. The solution is to leave some parameters
        for the outer FSDP.

    .. warning::

        If activation checkpointing is used with FSDP, it is strongly encouraged
        to use the ``checkpoint_wrapper`` function from FairScale instead of the
        ``checkpoint`` function from PyTorch.

    Args:
        module (nn.Module):
            module to be wrapped with FSDP.
        process_group (Optional):
            process group for sharding
        process_group_reduce_scatter (Optional):
            process group for the reduce-scatter operation. It defaults to
            ProcessGroupName.reduce_scatter, in which case a separate process group is
            initialized and assigned to the reduce_scatter operation, and the
            reduce_scatter operation overlaps with other operations in the backward
            propagation. If a specific ProcessGroup is passed in, reduce_scatter operates
            on that ProcessGroup and the overlap still happens.
            To disable the overlap feature, set the process group to
            ProcessGroupName.default; in this case the reduce_scatter operation uses
            the same process group as the default group.
            If the reduce-scatter process group size is different from the default
            process group size, the reduce_scatter operation falls back to using the
            default process group.
        reshard_after_forward (bool, Optional):
            if ``True``, reshard parameters after the forward pass. This saves
            memory but slows training. This is only relevant when resharding
            individual layers.
        disable_reshard_on_root (bool, Optional):
            If ``True``, ``reshard_after_forward`` will be set to ``False`` if the module is an
            FSDP root module, to improve performance. In that case, we do not reshard the full
            parameters of the FSDP root module since those parameters are needed immediately for
            the backward pass.
            If ``False``, performance will be lower, but this setting helps to save memory.
            Consider the case where an FSDP root module is a submodule of a model: its backward
            pass may not start immediately after the root module finishes its forward pass, so
            resharding the parameters of the FSDP root module can save memory in this case.
            Default: True.
        mixed_precision (bool, Optional):
            if ``True``, inputs, activations and gradients will be kept in FP16;
            computation and communication will occur in FP16; and a (sharded)
            master copy of the model weights will be maintained in FP32.
        fp32_reduce_scatter (bool, Optional):
            if ``True``, then reduce-scatter gradients in FP32. This is only
            relevant when *``mixed_precision``* is ``True``.
        flatten_parameters (bool, Optional):
            if ``True``, flatten parameters into a single contiguous tensor,
            which improves training speed.
        move_params_to_cpu (bool, Optional):
            if ``True``, offload params to CPU.
        compute_dtype (torch.dtype, Optional):
            dtype for full parameters for computation. This defaults to
            ``torch.float32`` unless *``mixed_precision``* is set, in which case
            it defaults to ``torch.float16``.
        buffer_dtype (torch.dtype, Optional):
            dtype for buffers for computation. This defaults to ``compute_dtype``.
        move_grads_to_cpu (bool, Optional):
            move gradient shard to CPU after reduction. This is useful when
            combined with CPU-based optimizers. It defaults to the value of
            *``move_params_to_cpu``*.
        bucket_cap_mb (int, Optional):
            FSDP will bucket parameters so that gradient reduction can
            be more efficient for small parameters.
            ``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets
            are sub-divided based on world_size, so the max shard size is roughly
            ``bucket_cap_mb / world_size``. There is one bucketer (with potentially
            multiple ``bucket_cap_mb``-sized buffers) shared by all FSDP instances.
            Large gradient tensors are directly reduced without using the buffers.
            The buffers are there to reduce communication overhead for small tensors.
            Overlapping with computation happens due to the use of a different CUDA stream
            than the computation CUDA stream. The total memory overhead per buffer is around
            ``bucket_cap_mb / world_size * (world_size + 1)``.
            The buffers are allocated during the backward pass and freed at the end
            of the backward pass to save more memory for other phases of the
            training process.
            Note that the memory vs. speed tradeoff of bucket size is very different
            from that of the DDP engine. In DDP, the buffer size is ``1MB + n*cap_mb``,
            until n is big enough to cover the entire model size. The order in which
            buffers become ready there is more rigid, and DDP requires all gradients
            to be computed in the backward pass. In FSDP, the buffer size does not
            change with model size (it changes based on the number of
            <dtype, device, process_group> tuples), and the gradient-ready order matters
            little since FSDP has a final flush call that ensures everything is reduced;
            not all gradients need to be known upfront. Overlapping with compute is
            done differently too.
            Values <= 0 disable bucketing.
            Default: 25.
        compute_device (torch.device, Optional):
            device for computation. If not given and module params are on a CUDA
            device, the param's device will be used. If not given and module
            params are on CPU, then the current CUDA device (as indicated by
            ``torch.cuda.current_device()``) will be used.
        no_broadcast_optim_state: (bool, Optional)
            do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
            If you set this to ``True``, you are expected to overwrite the relevant state entries of the
            returned optimizer state dict with the proper state at each rank. This is useful for situations,
            like Mixture Of Experts, where all but a few parameters can fit on one node.
            Default: False
        state_dict_device (torch.device, Optional):
            device for parameters returned by :func:`state_dict`. If not given,
            this will default to ``compute_device``. Note that only the device
            type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
        clear_autocast_cache (bool):
            When using mixed precision training with `torch.amp.autocast`, if the model weights
            are in FP32, autocast maintains a cache for downcasted weights. The cache can cause
            GPU OOM during the forward pass. Setting this flag to ``True`` helps clear this
            cache as inner FSDP instances finish their part of the forward pass, to save GPU memory.
            Default: False
        force_input_to_fp32 (bool):
            Set to ``True`` to force input floating point tensors to be FP32 (if they are FP16)
            when the FSDP instance is in full precision mode. This helps avoid issues when running
            SyncBatchNorm with AMP and checkpoint_wrapper.
            Default: False
        verbose (bool):
            Set this to ``True`` to turn on verbose output for the model's string representation.
            Default: False
        cpu_offload (bool, Optional):
            if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
            *``move_params_to_cpu``* in an upcoming release.
        offload_config (OffloadConfig):
            The `OffloadConfig` object is used to specify the type of offload (i.e., SSD, CPU) and
            other required knobs when offloading parameters from GPU. Currently the OffloadConfig
            only supports specifying SSD offload as an option. Note: This is an experimental feature.
        state_dict_on_rank_0_only (bool):
            When set to ``True``, ``model.state_dict()`` will only return the full state dict on
            rank 0 and return an empty dict on other ranks, which allows FullyShardedDataParallel to
            skip the GPU -> CPU copy on non-zero ranks altogether and prevent OOM.
            Default: False
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        # The type of process_group_reduce_scatter can only be either ProcessGroup or ProcessGroupName.
        process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
        reshard_after_forward: bool = True,
        disable_reshard_on_root: bool = True,
        mixed_precision: bool = False,
        fp32_reduce_scatter: bool = False,
        flatten_parameters: bool = True,
        move_params_to_cpu: bool = False,
        compute_dtype: Optional[torch.dtype] = None,
        buffer_dtype: Optional[torch.dtype] = None,
        move_grads_to_cpu: Optional[bool] = None,
        bucket_cap_mb: int = 25,
        compute_device: Optional[torch.device] = None,
        no_broadcast_optim_state: Optional[bool] = False,
        state_dict_device: Optional[torch.device] = None,
        clear_autocast_cache: bool = False,
        force_input_to_fp32: bool = False,
        verbose: bool = False,
        cpu_offload: bool = False,
        offload_config: Optional[OffloadConfig] = None,
        state_dict_on_rank_0_only: bool = False,
    ):
        try:
            import torch._C

            torch._C._log_api_usage_once("fairscale.fsdp")
        except ImportError:
            pass

        init_start = time.time()
        super().__init__()
        self.process_group = process_group or get_process_group_cached()
        # If ProcessGroupName.default is passed in, reduce_scatter will use the same process group as
        # the rest of the operations. The overlap feature in the backward propagation is disabled.
        if process_group_reduce_scatter == ProcessGroupName.default:
            self.process_group_reduce_scatter = self.process_group
        # If ProcessGroupName.reduce_scatter is passed in, reduce_scatter uses a separate process group
        # so that the overlap feature in the backward propagation is enabled.
        elif process_group_reduce_scatter == ProcessGroupName.reduce_scatter:
            self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
        else:
            # If a specific process group is passed in, the reduce_scatter will use the passed in process group.
            if isinstance(process_group_reduce_scatter, ProcessGroup):
                self.process_group_reduce_scatter = process_group_reduce_scatter
            else:
                if not hasattr(process_group_reduce_scatter, "allgather") and hasattr(
                    process_group_reduce_scatter, "rank"
                ):
                    # Likely a dummy pg for unit test
                    self.process_group_reduce_scatter = process_group_reduce_scatter
                else:
                    raise TypeError("unsupported type for reduce_scatter process group")
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        # In a dummy unit test environment, the process_group_reduce_scatter can be None.
        if self.process_group_reduce_scatter is not None:
            reduce_scatter_group_size = self.process_group_reduce_scatter.size()
            # Fall back to the default process group for the reduce-scatter operation when the
            # world size and the reduce-scatter process group size are different.
            if self.world_size != reduce_scatter_group_size:
                self.process_group_reduce_scatter = self.process_group
                logging.warn(
                    "Rolled back to use the default process group for the reduce scatter operation because the reduce_scatter process group "
                    f"size is {reduce_scatter_group_size}, which is different from the world size {self.world_size}. Please make sure the process_group "
                    "parameter uses all the available ranks for optimal performance."
                )
        self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
        self.disable_reshard_on_root = disable_reshard_on_root
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.flatten_parameters = flatten_parameters
        self.move_params_to_cpu = move_params_to_cpu or cpu_offload
        self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb
        self.compute_device = compute_device or _get_default_cuda_device(module)
        self.uncollected_opt_state: Dict[int, Dict] = {}
        self.no_broadcast_optim_state = no_broadcast_optim_state
        self.state_dict_device = state_dict_device or self.compute_device
        self.clear_autocast_cache = clear_autocast_cache
        self.force_input_to_fp32 = force_input_to_fp32
        self.verbose = verbose
        self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
        # Experimental feature for now. Use at your own risk.
        self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False

        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.numel_padded_per_param: List[int] = []
        self._tstart = time.time()

        if self.fp32_reduce_scatter and not self.mixed_precision:
            raise ValueError("fp32_reduce_scatter requires mixed_precision=True")

        if self.ssd_offload and not self.flatten_parameters:
            raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")

        # skip validation if the process group was created above
        if process_group:
            validate_process_group(self.compute_device, self.process_group)

        # enable pytorch sync_bn just in case model contains sync_bn layers.
        enable_pytorch_sync_bn(module)

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        param_names = []
        params = []
        for param_name, param in module.named_parameters():
            if not hasattr(param, "_is_sharded"):
                param_names.append(param_name)
                params.append(param)

        self._has_params = len(params) > 0
        self._has_shared_params = False

        # TODO(anj): Should we conditionally do this only if we have params?
        # TODO(anj): Figure out if we can allocate the buffer during sharding.
        self.buffer_size = sum(p.numel() for p in params)
        self.ssd_directory = tempfile.gettempdir()
        if self.ssd_offload:
            assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
            if offload_config and offload_config.dir:
                self.ssd_directory = offload_config.dir
            self.move_grads_to_cpu = True
            self.move_params_to_cpu = True

        # For now, it is either all flatten or none flatten. This will be extended to
        # multiple flatten groups in my next PR.
        to_be_flatten_params: List[List[Parameter]] = [[]]
        non_flatten_params = params
        param_name_groups = [[n] for n in param_names]
        if self.flatten_parameters:
            to_be_flatten_params = [params]
            non_flatten_params = []
            param_name_groups = [param_names]
        del param_names

        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
            module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
        )
        del module  # free original module in case it helps garbage collection

        # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
        # params for doing sharding, gradient hooks, etc. Note, the ordering of the
        # list matters: flatten params are always in the front.
        #
        # The self._num_flatten_params and self._param_name_groups are computed
        # and kept here to support summon_full_params and shard-to-full weight
        # consolidation.
        self.params = cast(List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params
        self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
        self._param_name_groups = param_name_groups

        # Shard module parameters in place
        self._shard_parameters_()

        # Make sure all parameters are sharded.
        for n, p in self.named_parameters():
            assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"

        self._reset_lazy_init()

        # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
        self._require_backward_grad_sync: bool = True

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState.IDLE

        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
        # prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        # TODO(anj): This should by default be set to False for ssd_offload=True
        # unless we are in the summon_full_params context.
        self._return_full_state_dict = True
        init_end = time.time()

        logging.debug(
            f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
        )

        # Flag to guard against preparing gradients multiple times per iteration.
        # This is reset at the end of the backward pass.
        self._pre_backward_hook_has_run = False

        # Free all params at the end of initialization.
        if self.ssd_offload:
            for m in self.modules():  # includes self
                if isinstance(m, FullyShardedDataParallel):
                    m._free_ssd_offload()

    def _get_gradient_predivide_factor(self, world_size: int) -> float:
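        # Heuristic (explanatory comment; intent inferred from set_gradient_divide_factors):
        # pick a power-of-two factor roughly equal to sqrt(world_size). Gradients are
        # pre-divided by this factor before the reduction and post-divided by
        # world_size / factor afterwards, splitting the division to reduce the risk of
        # FP16 overflow/underflow when reducing across many workers.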
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)

    def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
        """Allowing user to override the pre and post divide factors.

        Args:
            pre (float): divide factor before the reduction.
            post (float): divide factor after the reduction.
            recursive (bool): recursively set it for all child FSDP instances or not.
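
        Example (an illustrative sketch; ``fsdp_model`` stands for the root FSDP
        instance and is not defined in this module)::

            # Divide only once, after the reduce-scatter.
            fsdp_model.set_gradient_divide_factors(1.0, float(fsdp_model.world_size), True)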
        """
        self.assert_state(TrainingState.IDLE)
        if recursive:
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel) and module != self:
                    module.set_gradient_divide_factors(pre, post, False)
        self.gradient_predivide_factor = pre
        self.gradient_postdivide_factor = post

    @property
    def module(self) -> FlattenParamsWrapper:
        """make model.module accessible, just like DDP."""
        assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
        return self._fsdp_wrapped_module

    def append_shared_param(self, p: Parameter) -> None:
        """Add a param that's already owned by another FSDP wrapper.

            .. warning:: This is experimental!

            This only works when all the FSDP modules sharing the param are un-flattened.

            p must already be sharded by the owning module.

            Check the corresponding unit test to see how it is used and tested.
            In particular, the sharing FSDP wrappers are "siblings", not "parent"
            and "child", of each other in the nested module structure.
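
            Example (an illustrative sketch; ``fsdp_a`` and ``fsdp_b`` are sibling,
            un-flattened FSDP wrappers and ``shared_weight`` is a parameter owned and
            already sharded by ``fsdp_a``)::

                fsdp_b.append_shared_param(shared_weight)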

        Args:
            p (Parameter):
                The shared parameter.
        """
        assert self._is_root is None
        assert not self.flatten_parameters
        assert isinstance(p, Parameter)
        assert p._is_sharded
        p._is_shared = True
        assert (
            len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
        ), "Must have at least 1 non-shared param."
        self.params.append(p)
        self._has_shared_params = True

    def non_shared_params(self) -> List[nn.Parameter]:
        """Return the list of non-shared parameters."""
        if self._has_shared_params:
            return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
        else:
            return self.params

    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
        """
        Applies ``fn`` recursively to every submodule (as returned by
        ``.children()``) as well as self. Typical use includes initializing the
        parameters of a model.

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (nn.Module): function to be applied to each submodule

        Returns:
            Module: self
        """
        is_uninitialized = self._is_root is None
        self.assert_state(TrainingState.IDLE)
        with self.summon_full_params(recurse=False):
            return_value = super().apply(fn)
        # summon_full_params will call _lazy_init, which sets _is_root. However,
        # apply() may be called directly on children instances to do weight
        # init, so we should reset the _is_root flag in this case.
        if is_uninitialized and self._is_root:
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel):
                    module._reset_lazy_init()
        return return_value

    def _cast_buffers(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
    ) -> None:
        """Move all buffers to the given *device* and *dtype*.

        If *device* or *dtype* are not given, then they will default to
        ``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
        case of nested FSDP instances, we will respect the child instance's
        ``compute_device`` and ``buffer_dtype`` configuration.

        Args:
            device (torch.device, Optional):
                device to cast buffers to (defaults to compute_device)
            dtype (torch.dtype, Optional):
                dtype to cast buffers to (defaults to buffer_dtype)
            memo (Set, Optional):
                set of modules that have already been processed
        """
        if memo is None:
            memo = set()
        for module in self.modules():
            if module is not self and isinstance(module, FullyShardedDataParallel):
                # Allow any child FSDP instances to handle their own buffers.
                module._cast_buffers(device=device, dtype=dtype, memo=memo)
            elif module not in memo:
                memo.add(module)
                for name, buf in module.named_buffers(recurse=False):
                    if buf is None:
                        continue
                    buf = buf.to(device=device or self.compute_device)
                    if torch.is_floating_point(buf):
                        buf = buf.to(dtype=dtype or self.buffer_dtype)
                    setattr(module, name, buf)

    @property
    def params_with_grad(self) -> List[Parameter]:
        """[p for p in self.parameters() if p.grad is not None]"""
        return [p for p in self.parameters() if p.grad is not None]

    @torch.no_grad()
    def clip_grad_norm_(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        # filter_params_fn: Callable[[Any], Any] = None,
    ) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all
        gradients together, as if they were concatenated into a single vector.
        Gradients are modified in-place.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
                for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
            handles the partitioning and multiple devices per rank under the
            hood. The default torch util is not applicable here, because each
            rank only has a partial view of all the grads in the model, so
            calling it in the OSS context would lead to different scaling being
            applied per subset of model parameters.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
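
        Example (an illustrative sketch; assumes ``model`` is the root FSDP instance,
        ``optimizer`` wraps its parameters, and gradients have already been computed
        via ``loss.backward()``)::

            total_norm = model.clip_grad_norm_(max_norm=1.0)
            optimizer.step()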
        """
        # We don't call torch.cuda.synchronize() here, since clipping can be
        # inside the train loop and we probably don't want to force a GPU-CPU sync.
        # _lazy_init should be sufficient, since it will force the other streams
        # to sync with the default stream (via _wait_for_previous_optim_step).
        self._lazy_init()
        assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
        self.assert_state(TrainingState.IDLE)

        max_norm = float(max_norm)
        norm_type = float(norm_type)
        params_with_grad = self.params_with_grad
        if not self.children_share_process_group:
            raise NotImplementedError(
                "clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
            )
        # Computes the max norm for this shard's gradients and sync's across workers
        local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
        if norm_type == inf:
            total_norm = local_norm
            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
        else:
            total_norm = local_norm ** norm_type
            dist.all_reduce(total_norm, group=self.process_group)
            total_norm = total_norm ** (1.0 / norm_type)

        if self.move_grads_to_cpu:
            total_norm = total_norm.cpu()

        # Now multiply each grad by (max_norm / total_norm), same as torch 1.7 (https://tinyurl.com/3wtxhhqq).
        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
        if clip_coef < 1:
            # multiply by clip_coef
            for p in params_with_grad:
                assert p.grad is not None
                p.grad.detach().mul_(clip_coef.to(p.grad.device))

        return total_norm

    @torch.no_grad()
    def _shard_parameters_(self) -> None:
        """
        At initialization we wrap a module with full parameters and shard the
        parameters in-place. Sharding is implemented by viewing each parameter
        as a 1D Tensor and retaining only a single slice, where the slice size
        is determined by the number of data parallel workers.

        Wrapping modules with many small parameters (or with a very large data
        parallel world size) will result in many small parameter shards and slow
        performance. In this case it's better to set *``flatten_parameters``* to
        ``True``, so that all of the small parameters in the module are combined
        into a single contiguous Tensor and sharded once.

        After this initial sharding is complete, the user can initialize a
        ``torch.optim.Optimizer`` in the usual way, i.e.::

        .. code-block:: python

            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

        The optimizer will see only a single slice of parameters and will thus
        allocate less memory for optimizer state, avoiding redundancy across
        data parallel workers.
        """
        self.numel_padded_per_param = []
        for p in self.params:
            assert not hasattr(p, "_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p._is_sharded = self.world_size > 1
            p._orig_size = p.data.size()

            if not p._is_sharded:
                if not self.ssd_offload:
                    p._is_sharded = False
                    self.numel_padded_per_param.append(0)
                    continue
            p._is_sharded = True

            # Replace p.data with the relevant shard.
            if self.ssd_offload:
                assert isinstance(p, SsdFlatParameter)
                sharded_tensor, num_padded = self._get_shard(p.data)
                p.point_to_resized_tensor(sharded_tensor)
                self.numel_padded_per_param.append(num_padded)
                p.to_file()
            else:
                orig_data = p.data
                p.data, num_padded = self._get_shard(p.data)
                self.numel_padded_per_param.append(num_padded)
                free_storage_(orig_data)

        assert len(self.numel_padded_per_param) == len(self.params)

    def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the local shard of a full tensor."""
        # Shard using torch.chunk to match all-gather/reduce-scatter.
        chunks = list(torch.flatten(tensor).chunk(self.world_size))
        while len(chunks) < self.world_size:
            chunks.append(chunks[0].new_empty(0))

        # Determine number of padding elements.
        num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
        assert num_to_pad >= 0, num_to_pad

        shard = chunks[self.rank].clone()
        if num_to_pad > 0:
            shard = F.pad(shard, [0, num_to_pad])
        return shard, num_to_pad

    def extra_repr(self) -> str:
        repr = (
            f"world_size={self.world_size}, "
            f"flatten_parameters={self.flatten_parameters}, "
            f"mixed_precision={self.mixed_precision}, "
        )
        if self.verbose:
            repr = (
                f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
                f"compute_dtype={self.compute_dtype}, "
                f"buffer_dtype={self.buffer_dtype}, "
                f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
                f"compute_device={self.compute_device}, "
                f"move_params_to_cpu={self.move_params_to_cpu}, "
                f"move_grads_to_cpu={self.move_grads_to_cpu}, "
                f"bucket_cap_mb={self.bucket_cap_mb}, "
                f"clear_autocast_cache={self.clear_autocast_cache}, "
                f"force_input_to_fp32={self.force_input_to_fp32}"
            )
        return repr

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def __getstate__(self) -> Dict[str, str]:
        """Serialize the state of the current FSDP instance.
826

        Some properties are not serializable (e.g., process groups, streams), so
        we remove them and try to reconstruct them in :func:`__setstate__`.
        """
        state = copy.copy(self.__dict__)
        state["is_sharded"] = [p._is_sharded for p in self.params]
        state["orig_sizes"] = [p._orig_size for p in self.params]
        if state["process_group"] is not None:
            state["process_group"] = "MISSING"  # process_group isn't pickleable
        if state["process_group_reduce_scatter"] is not None:
            state["process_group_reduce_scatter"] = "MISSING"  # process_group_reduce_scatter isn't pickleable
        self._reset_lazy_init()
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Intercept state setting and perform needed changes on params."""
        super().__setstate__(state)

        def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
            assert isinstance(p, Parameter)
            p.data = p.data.clone()  # move tensors out of shared memory
            p._is_sharded = is_sharded
            p._orig_size = size
            return p

        self.params = [
            fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
        ]
        del self.is_sharded
        del self.orig_sizes
        self._reset_lazy_init()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Returns an iterator over the module parameters, yielding all the parameters
        that are part of the model.
        """

        return super().parameters(recurse=recurse)

    def named_parameters(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Parameter]]:
        """Returns an iterator over the module parameters, yielding both the name of the
        parameter as well as the parameter.

        With FSDP, the `named_parameters` function implemented in `nn.Module` will not
        be able to return the name and param when we use flattened parameters unless
        we call this function under a `summon_full_params` context.

        If you want the full param to be returned, you should call this function
        under a `summon_full_params` context when using flattened or original params.
        """
        named_param = super().named_parameters(*args, **kwargs)
        for name, param in named_param:
            if (
                hasattr(self, "flatten_parameters")
                and self.flatten_parameters
                and hasattr(self, "training_state")
                and self.training_state != TrainingState.SUMMON_FULL_PARAMS
            ):
                yield name, param
            else:
                yield _clean_path(name), param

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is a nn.Sequential."""
        return self.module.__getitem__(key)

    @typing.overload
    def state_dict(
        self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
    ) -> Mapping[str, torch.Tensor]:
        ...

    @typing.overload
    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
        ...

    # Since we have overloads above, we can use Any here.
    def state_dict(self, *args: Any, **kwargs: Any) -> Any:
        """
        Returns the whole (unsharded) state of the module. Parameters are not
        sharded, so the resulting state_dict can be loaded directly by the
        wrapped Module without any sharding-specific logic. Returned tensors
        will be full precision (e.g., FP32).

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
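
        Example (an illustrative sketch; ``sharded_module`` is an FSDP instance and the
        call is made on every rank)::

            full_state = sharded_module.state_dict()  # unsharded, full-precision tensors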
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self._lazy_init()

        def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
            if self.mixed_precision:
                self._cast_buffers(dtype=dtype)

        if self._return_full_state_dict:
            if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
                with self.summon_full_params(recurse=False, volatile=True):
                    maybe_cast_buffers(torch.float32)
                    state_dict = super().state_dict(*args, **kwargs)
            else:
                maybe_cast_buffers(torch.float32)
                state_dict = super().state_dict(*args, **kwargs)
        else:
            maybe_cast_buffers(torch.float32)
            state_dict = self.module.flat_state_dict(*args, **kwargs)

        if self.move_params_to_cpu:
            for k in state_dict.keys():
                state_dict[k] = state_dict[k].cpu()

        # In case we are in mixed precision, restore buffers back to buffer_dtype.
        maybe_cast_buffers()
        return state_dict

    @typing.overload
    def local_state_dict(
        self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
    ) -> Mapping[str, torch.Tensor]:
        ...

    @typing.overload
    def local_state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
        ...

    # Since we have overloads above, we can use Any here.
    def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
        """
        Returns the local (sharded) state of the module. Parameters are sharded,
        so the resulting state_dict can only be loaded after the Module has been
        wrapped with FSDP.
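
        Example (an illustrative sketch; each rank saves its own shard, with ``rank``
        obtained from the process group)::

            torch.save(sharded_module.local_state_dict(), f"shard_rank{rank}.pt")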
        """
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            # We need to specially call FSDP's state_dict function in case
            # self.state_dict is a function from a child class of FSDP.
            return FullyShardedDataParallel.state_dict(self, *args, **kwargs)

    @contextlib.contextmanager
    def _no_return_full_state_dict(self) -> Generator:
        backup = self._return_full_state_dict
        self._return_full_state_dict = False

        if self.ssd_offload:
            # Move params from disk to memory before returning the local state dict.
            self._move_params_to_memory()

        try:
            yield
        finally:
            self._return_full_state_dict = backup

    def _move_params_to_memory(self) -> None:
        """Move params from disk to CPU."""
        for p in self.params:
            assert isinstance(p, SsdFlatParameter)
            p.to_tensor()

    def _load_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """
        Load a whole (unsharded) state_dict.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
        """
        if self._return_full_state_dict:
            with self.summon_full_params():
                return self.module.load_state_dict(state_dict, strict)
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            return self.module.load_state_dict(state_dict, strict)

    def load_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        return self._load_state_dict(state_dict, strict)

    def load_local_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """Load a local (sharded) state_dict."""
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            output = self._load_state_dict(state_dict, strict)
        return output

    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """
        A context manager to disable gradient synchronizations across FSDP
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context.

        .. note:: This likely results in higher memory usage because FSDP will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.

        .. note:: Gradient accumulation can be done without this context,
            avoiding the extra GPU memory overhead, but with the extra
            networking overhead.
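
        Example (an illustrative sketch; ``model`` is the root FSDP instance and
        ``first_batch``/``second_batch`` are input tensors)::

            with model.no_sync():
                model(first_batch).sum().backward()   # grads accumulate, no reduce-scatter
            model(second_batch).sum().backward()      # gradients are synchronized here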
        """
        self._lazy_init()
        assert self._is_root, "no_sync on inner FSDP is not supported"
        self.assert_state(TrainingState.IDLE)
        # This instance may wrap other FSDP instances and we
        # need to set all of them to accumulate gradients.
        old_flags = []
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                old_flags.append((m, m._require_backward_grad_sync))
                m._require_backward_grad_sync = False
        try:
            yield
        finally:
            for m, old_flag in old_flags:
                assert m._require_backward_grad_sync is False
                m._require_backward_grad_sync = old_flag

    @contextlib.contextmanager
    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
        """
        A context manager to expose full params for the current FSDP instance.
        Can be useful *after* forward/backward for a model to get the params for
        additional processing or checking. Parameters will be gathered in full
        precision (e.g., FP32).

        .. note:: This can be used on inner FSDPs.

        .. note:: This can *not* be used within a forward or backward pass. Nor
            can forward and backward be started from within this context.

        .. note:: The full parameters will be freed after the context manager
            exits; it is up to the caller to clone them if needed.

        .. note:: The full parameters can be modified, but only the portion
            corresponding to the local param shard will persist after the
            context manager exits (unless ``volatile=True``, in which case there
            are no guarantees about persistence).

        Args:
            recurse (bool, Optional): recursively summon all params for nested
                FSDP instances (default: True)
            volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
                enabling this can be slightly more efficient (default: False)
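
        Example (an illustrative sketch; ``model`` is an FSDP instance)::

            with model.summon_full_params():
                # Full, unsharded parameters are available here, e.g. for inspection.
                total_numel = sum(p.numel() for p in model.parameters())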
        """
        if recurse:
            with contextlib.ExitStack() as stack:
                # Summon all params for any nested FSDP instances.
                for module in self.modules():
                    if isinstance(module, FullyShardedDataParallel):
                        stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
                # Yield to the caller, with full params in all nested instances.
                yield
            # Exiting from the ExitStack will re-shard params.
            return
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            self.assert_state(TrainingState.IDLE)
            # Set the state so that we assert when trying to go into
            # forward/backward.
            self.training_state = TrainingState.SUMMON_FULL_PARAMS
            full_tensors = self._rebuild_full_params(force_full_precision=True)
            assert full_tensors is not None
            with contextlib.ExitStack() as stack:
                if self.module.is_flattened:
                    # Update flattened views to point to fully-sized tensors. We
                    # use self.params instead of full_tensors since the
                    # latter may contain padding.
                    stack.enter_context(
                        self.module.unflatten_params(
                            flat_params=[p.data for p in self.params[: self._num_flatten_params]]
                        )
                    )
                try:
                    yield
                finally:
                    stack.close()
                    non_shared_params = self.params
                    # filter out shared params for all but the owner FSDP module.
                    if len(full_tensors) < len(non_shared_params):
                        non_shared_params = self.non_shared_params()
                    assert len(full_tensors) == len(
                        non_shared_params
                    ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
                    for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
                        if not volatile:
                            # Copy any changes made to the full params back into
                            # the corresponding local shards.
                            local_shard, _ = self._get_shard(full_tensor)
                            p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                        if safe_to_free:
                            free_storage_(full_tensor)
                    self.has_full_params = False
                    if self.ssd_offload:
                        # Store tensors in the SSD buffer and free param storage.
                        for p in self.params:
                            assert isinstance(p, SsdFlatParameter)
                            p.to_file()
                    else:
                        self._use_fp32_param_shard()
                    self.training_state = TrainingState.IDLE

    def _reset_lazy_init(self) -> None:
        """Reset instance so :func:`_lazy_init` will run on the next forward."""
        self._is_root: Optional[bool] = None
        self._streams: Dict[str, torch.cuda.Stream] = {}
        self._reducer: Optional[ReduceScatterBucketer] = None
        for p in self.params:
            if hasattr(p, "_fp32_shard"):
                del p._fp32_shard  # reset _init_param_attributes
        self._output_pre_backward_hook_registered: Optional[List] = None
        self.reshard_after_forward = self._orig_reshard_after_forward

    def _lazy_init(self) -> None:
        """Initialization steps that should happen lazily, typically right
        before the first forward pass.
        """
        # Initialize param attributes lazily, in case the param's dtype or
        # device changes after __init__.
        for p in self.params:
            self._init_param_attributes(p)

        # Initialize _is_root and set up streams. These steps would ideally
        # happen in __init__, but _is_root can only be determined after the
        # entire model hierarchy is set up, thus we run it lazily.
        if self._is_root is None:
            self._set_is_root()
            self._setup_streams()
            self._setup_output_hook_list()

        if self._is_root:
            # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
            # applies recursively, we only call this from the root instance.
            self._cast_buffers()

            if self.disable_reshard_on_root:
                # Don't free the full params for the outer-most (root) instance,
                # since those params will be needed immediately after for the
                # backward pass.
                self.reshard_after_forward = False

            # Due to the use of streams, we need to make sure the previous
            # ``optim.step()`` is done before we all-gather parameters.
            self._wait_for_previous_optim_step()

    @torch.no_grad()
    def _init_param_attributes(self, p: Parameter) -> None:
        """
        We manage several attributes on each Parameter instance. The first two
        are set by :func:`_shard_parameters_`:

            ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param).
            ``_orig_size``: the size of the original Parameter (before sharding)

        The remaining attributes are set here:
            ``_fp32_shard``: a single shard of the parameters in full precision
                (typically FP32, but this is dependent on the dtype of the model
                as it's passed in by the user). This can be on CPU or GPU
                depending on the value of *``move_params_to_cpu``*.
            ``_fp16_shard``: a single shard of the parameters used for the all-gather.
                Despite the name, this can be FP16 or FP32, depending on the value of
                *``compute_dtype``* and whether params are offloaded to CPU.
            ``_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. This will be resized in place and
                only materialized (via all-gather) as needed.
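
        Illustrative sizes for a parameter with 8 elements sharded over
        ``world_size=4`` (a sketch, assuming FP32 master weights with FP16
        compute)::

            p._fp32_shard          # 2 elements, FP32, persistent master shard
            p._fp16_shard          # 2 elements, compute dtype, freed when unused
            p._full_param_padded   # 8 elements, compute dtype, freed when unused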
        """
        assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
        if hasattr(p, "_fp32_shard"):
            return

        # A single shard of the parameters in full precision.
        # TODO(another-pjohnson) - I believe this will cause memory leakage with ssd
        #            p.data returns a pointer to a handle, and that handle has its
        #            ref count incremented by p._fp32_shard. So this tensor will
        #            never be freed even if we do p.to_disk(). Investigate after
        #            PR #887 is merged.
        p._fp32_shard = p.data

        if self.mixed_precision:
            assert p._fp32_shard.dtype == torch.float32
        if self.move_params_to_cpu:
            assert p._fp32_shard.device == torch.device("cpu")

            # We don't pin memory when using ssd_offload since that results in OOM when
            # the memory requirements of a model are larger than host memory.
            if not self.ssd_offload:
                # If we plan to keep the FP32 parameters on CPU, then pinning
                # memory allows us to later use non-blocking transfers when moving
                # the FP32 param shard to compute_device.
                p._fp32_shard = p._fp32_shard.pin_memory()
                p.data = p._fp32_shard

        if self.move_params_to_cpu or self.mixed_precision:

            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed. If offloading params to CPU, the
            # dtype of the fp16 shard will depend on the *`compute_dtype`*.
            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            free_storage_(p._fp16_shard)

        if self.mixed_precision:
            assert p._fp32_shard.dtype == torch.float32

        if not self.mixed_precision and not self.move_params_to_cpu:
            # Use _fp32_shard if you are not using mixed precision and not
            # offloading params and grads to CPU.
            p._fp16_shard = None

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p._is_sharded:
            p._full_param_padded = torch.zeros(
                p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
            )
            free_storage_(p._full_param_padded)

        if self.move_grads_to_cpu and self.training:
            # We can optionally move the grad shard to CPU during the backward
            # pass. In this case, it's important to pre-allocate the CPU grad
            # shard in pinned memory so that we can do a non-blocking transfer.
            # This is only needed during training and not evaluation.
            if self.ssd_offload:
                assert isinstance(p, SsdFlatParameter)
                # Gradients also need to be offloaded to SSD; otherwise it can result in
                # OOMs when the memory requirements of a model are larger than host memory.
                p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
                p._cpu_grad.set_file_params(p.filename + "_grad", 0)
                p._cpu_grad.to_file()
            else:
                p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

    def _set_is_root(self) -> None:
        """If ``True``, implies that no other :class:`FullyShardedDataParallel`
        instance wraps this one. Called once by :func:`_lazy_init`.
        Also sets self.children_share_process_group = True if all child
        instances share the same process group. If some child instances use a
        different process group, self.clip_grad_norm_ will raise an error.
        """
        if self._is_root is not None:
            return
        # No FSDP instance wraps this, else _is_root would be set to False.
        self._is_root = True
        # If the final backward callback was never queued, the state should be IDLE.
        # If it was queued, the callback should have finished and reset the state to IDLE.
        # This should be asserted at the beginning of the forward pass in the root instance only.
        # For child instances, if they are checkpointed, the state will not be reset to
        # IDLE after each inner forward/backward.
        self.assert_state(TrainingState.IDLE)
        # As the root, we now set all children instances to False and
        # give them a closure to try to queue a wait_for_post_backward.
        self.children_share_process_group = True
        for n, m in self.named_modules():
            # `n != ""` excludes self.
            if n != "" and isinstance(m, FullyShardedDataParallel):
                # We relax the assert for non-root instances, for the case where a nested, already
                # initialized module is wrapped in FSDP again later, e.g. after training, to run inference.
                assert m._is_root is None or not m._is_root
                if m._is_root is None:
                    m._is_root = False
                if m.process_group != self.process_group:
                    self.children_share_process_group = False

                # If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
                # Therefore, gathering this child's optim state will probably cause OOM, so we won't do it.
                m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
                    (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
                )

    def _setup_streams(self) -> None:
        """Create streams to overlap data transfer and computation."""
        if len(self._streams) > 0 or not self._is_root:
            return

        if torch.cuda.is_available():
            # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
            self._streams["fp32_to_fp16"] = torch.cuda.Stream()
            # Stream for all-gathering parameters.
            self._streams["all_gather"] = torch.cuda.Stream()
            # Stream for overlapping grad reduction with the backward pass.
            self._streams["post_backward"] = torch.cuda.Stream()

        # Helper for bucketing reduce-scatter ops. This is also shared with
        # children instances to improve bucket utilization.
        self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
        # We share streams with all children instances, which allows them to
        # overlap transfers across the forward pass without synchronizing with
        # the default stream.
        for n, m in self.named_modules():
            if n != "" and isinstance(m, FullyShardedDataParallel):
                m._streams = self._streams
                m._reducer = self._reducer

    def _setup_output_hook_list(self) -> None:
        """set up a list to avoid registering pre-backward hooks
        more than once on the same output tensor.
        """
        assert self._is_root, "This should only be called on the root"
        self._output_pre_backward_hook_registered = []
        for n, m in self.named_modules():
            if n != "" and isinstance(m, FullyShardedDataParallel):
                m._output_pre_backward_hook_registered = self._output_pre_backward_hook_registered

    def _wait_for_previous_optim_step(self) -> None:
        """
        The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
        instance) needs to synchronize with the default stream to ensure the
        previous optimizer step is done.
        """
        if not torch.cuda.is_available():
            return
        if self.mixed_precision or self.move_params_to_cpu:
            self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
        else:
            self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        self._lazy_init()

        # Start of a forward pass.
        self.training_state = TrainingState.FORWARD
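        # A descriptive summary of the forward flow implemented below: cast inputs
        # if needed -> all-gather the full params -> register post-backward hooks ->
        # run the wrapped module -> optionally reshard -> switch back to the FP32
        # shard -> register pre-backward hooks on the outputs.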

        # For the root instance in mixed precision mode, we convert the input to FP16
        # (no_grad is needed for the conversion).
        if self._is_root and self.mixed_precision:
            args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)

        # If enabled, convert the input to FP32 if we are in full precision.
        # no_grad is not used because the input might be for a non-root instance,
        # which means autograd needs to go through the conversion.
        if self.force_input_to_fp32 and not self.mixed_precision:
            args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)

        # All-gather full parameters. This will also transfer FP32 parameters to
        # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
        self._rebuild_full_params()

        # Register backward hooks to reshard params and reduce-scatter grads.
        # These need to be re-registered every forward pass.
        self._register_post_backward_hooks()

        outputs = self.module(*args, **kwargs)

        if self.reshard_after_forward:
            self._free_full_params()
            if self.mixed_precision or self.move_params_to_cpu:
                self._free_fp16_param_shard()

        # Switch to main FP32 param shard. We maintain this invariant throughout
        # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
        # also ensures that after the first forward, the optimizer state will be
        # initialized with the correct dtype and (sharded) size, since optimizer
        # state is typically initialized lazily in ``optim.step()``.
        self._use_fp32_param_shard()

        # Register pre-backward hooks to all-gather the params for the backward
        # pass (if output's grad was needed). This won't register anything if
        # we are in eval mode.
        #
        # Some models run the forward pass multiple times, so we need to register the
        # pre-backward hook on every output since the last output's hook has to
        # fire first to set up for backward. However, we use ``self._pre_backward_hook_has_run``
        # to prevent repeated overhead from multiple hook callbacks.
        outputs = self._register_pre_backward_hooks(outputs)

        # Done with a forward pass.
        self.training_state = TrainingState.IDLE

        # Only need to clear cache during forward. During backward, the cache is not used.
        # TODO (Min): Future PyTorch versions may provide a way to completely disable this
        #     cache. Update this when that's available.
        if self.clear_autocast_cache:
            torch.clear_autocast_cache()

        self._free_ssd_offload()

        return outputs

    @torch.no_grad()
    def _free_ssd_offload(self) -> None:
        if self.ssd_offload:
            for p in self.params:
                assert isinstance(p, SsdFlatParameter)
                p.to_file(permit_when_tensor_none=True)

    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
        """Register pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward.

        Returns:
            outputs: new outputs with hooks registered if they require gradients.
        """
        if not torch.is_grad_enabled():
            return outputs  # don't register hooks if grad isn't enabled

        if self._is_root:
            # This actually means that only the root instance has
            # _post_backward_callback_queued defined. Accidentally accessing this field
            # on a non-root instance will raise an error, giving us a nice bug checker.
            self._post_backward_callback_queued = False

        def _pre_backward_hook(*unused: Any) -> None:
            # try to queue final backward callback only once for root, so
            # that final backward callback is attached to the outer most
            # backward graph task and called after all the backward
            # calls are completed.
            if self._is_root:
                self._queue_wait_for_post_backward()

            # All-gather full parameters or switching to the full params.
            #
            # This needs to be done on every pre_backward hook, even within the same
            # iteration (i.e. for checkpointed, multiple forward pass modules). This is
            # because after the forward pass (i.e. in checkpoint inner graph), we always
            # switch to fp32_shard in the ``forward`` function.
            #
            # We used to do this only after the ``self._pre_backward_hook_has_run``
            # boolean guard below, which is incorrect. It worked in pytorch < 1.9 for
            # some unknown reason, but pytorch 1.10 nightly exposed this bug.
            #
            # Note, both ``self._rebuild_full_params`` and ``self._use_full_params`` are
            # idempotent.  So in case they are called unnecessarily, they don't incur much
            # overhead.
            if self.reshard_after_forward:
                self._rebuild_full_params()
            else:
                self._use_full_params()

            # Only run ``self._prep_grads_for_backward`` once per iteration (i.e. in case
            # there are multiple outputs or multiple forward passes).
            if not self._pre_backward_hook_has_run:
                self._pre_backward_hook_has_run = True
                # Start of a backward pass for the first time in an iteration.
                self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
                # Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
                self._prep_grads_for_backward()

            # Transition to BACKWARD_PRE state if currently IDLE. We can transition from BACKWARD_POST
            # to IDLE when FSDP is within activation checkpointing and called multiple times, due to the
            # extra forward pass for re-computation.
            if self.training_state == TrainingState.IDLE:
                self.training_state = TrainingState.BACKWARD_PRE
            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])

        _registered = 0

        def _register_hook(t: torch.Tensor) -> torch.Tensor:
            # We don't register the pre_backward hook on the same tensor that has been
            # returned from an inner FSDP, unless it is the first one. This does
            # not cover all problematic cases though. A tensor not from an inner
            # FSDP can cause problems too:
            # ```
            #   x = layer1(input)
            #   state = [x]  # better change to x.detach(), not fixed by the following if-condition
            #   x = inner_fsdp_module_layer2(x)
            #   state.append(x)  # better change to x.detach(), but fixed by the following if-condition
            #   x = layer3(x)
            #   return x, state
            # ```
            # The tensors in `state`, if not detached, can be registered with
            # backward hooks (in addition to the `x` on the last line). In that case,
            # the pre-backward hook can fire multiple times, in an order that causes
            # the outer FSDP to crash.
            #
            # The best practice is for a module wrapped by FSDP to return one and only
            # one tensor that is used for backward. All other returned tensors should be
            # detached.
            nonlocal _registered
            assert self._output_pre_backward_hook_registered is not None
            if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered):
                t.register_hook(_pre_backward_hook)
                self._output_pre_backward_hook_registered.append(id(t))
                _registered += 1
            return t

        # Attach hooks to Tensor outputs.
        outputs = apply_to_tensors(_register_hook, outputs)

        return outputs

    def _register_post_backward_hooks(self) -> None:
        """
        Register backward hooks to reshard params and reduce-scatter grads.

        This is called during forward pass. The goal is to attach a hook
        on each of the parameter's gradient generating function (``grad_acc``
        below) so that the hook is called *after* all gradients for that
        param are computed.

        Goals:

        1. We want the hook to fire once and only once *after* all gradients
        are accumulated for a param.
        2. If it fires more than once, we end up incorrectly sharding the grad
        multiple times (which could leave the grad dimension too small).
        3. If it fires once but too early, or doesn't fire at all, we leave the
        gradients unsharded (which could leave the grad dimension too large).

        Due to multiple-pass forward, this function can be called on
        the same parameter multiple times in a single forward pass. If we register
        the hook multiple times, we end up getting called multiple times. We
        could try to get a new hook every time and delete the previously
        registered one. However, for *unknown reasons* (I have debugged it for
        a long time!), in mixed precision mode we get two different ``grad_acc``
        objects below during different calls of this function (in the same
        forward pass). If we keep the last one, the hook ends up firing too
        early. In full precision mode, we luckily get the *same* ``grad_acc``
        object, so deleting and re-registering still ensures the hook fires
        once after all gradients are generated.

        Empirically, keeping the first hook registered per forward pass seems to
        work the best. We do need to remove the hook at the end of the
        backward pass. Otherwise, the next forward pass will not register
        a new hook, which is needed for a new forward pass.
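
        The core registration pattern, shown in isolation (an illustrative sketch
        of the code below, where ``my_hook`` stands in for the partial of
        ``_post_backward_hook``)::

            p_tmp = p.expand_as(p)  # creates a grad_fn on p_tmp
            grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node
            handle = grad_acc.register_hook(my_hook)  # fires after p's grad is ready
            p._shard_bwd_hook = (grad_acc, handle)  # kept so it can be removed later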
        """
        if not torch.is_grad_enabled():
            return  # don't register grad hooks if grad isn't enabled
        for p in self.params:
            if p.requires_grad:
                if hasattr(p, "_shard_bwd_hook"):
                    continue
                # Register a hook on the first call; empirically, autograd
                # fires it at the end for this param, which makes sense.
                p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
                assert p_tmp.grad_fn is not None
                grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
                handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
                p._shard_bwd_hook = (grad_acc, handle)

    @torch.no_grad()
    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        """
        At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will replace
        ``param.grad`` with a single shard of the summed gradient across all
        GPUs. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by :func:`_shard_parameters_`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state.
        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
        self.training_state = TrainingState.BACKWARD_POST
        if param.grad is None:
            return

        if hasattr(param, "_linked_param"):
            # This links to a shared param. We should finalize the linked param here.
            assert param.shape == (1,), param.shape
            # If the _is_shared flag is set, then this shared weight is indeed being
            # shared between different FSDP wrappers. Otherwise, they are linked but
            # likely in the same FSDP wrapper, which means we shouldn't finalize the
            # linked param.
            if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                param = param._linked_param

        assert param.grad is not None, param.shape
        if param.grad.requires_grad:
            raise RuntimeError("FSDP only works with gradients that don't require gradients")

        if self._require_backward_grad_sync or self.reshard_after_forward:
            # Free full params. As a special case, we don't free the full params
            # when in a ``no_sync`` context (as inversely indicated by
            # ``self._require_backward_grad_sync``), since the params will not
            # get updated before the next forward. This saves networking
            # bandwidth but uses more GPU memory.
            self._free_full_params([param])

        if self.mixed_precision:
            # This is a no-op if reshard_after_forward is True, since we already
            # free the param shard when rebuilding the full params in the
            # pre_backward_hook.
            self._free_fp16_param_shard([param])

        # Switch to FP32 shard after backward.
        self._use_fp32_param_shard([param])

        if not self._require_backward_grad_sync:
            return

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data

            if self.mixed_precision and self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.to(param.dtype)

            if self.gradient_predivide_factor > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
                param.grad.data.div_(self.gradient_predivide_factor)
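                # Note: the pre-division here and the post-division in
                # _post_reduction_hook are chosen so that their product equals the
                # world size, i.e. the gradient is averaged overall while each
                # individual division stays better scaled for reduced precision.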

            if param._is_sharded:
                assert self._reducer is not None
                # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                # matter, neglecting rounding.
                grad = param.grad.data
                # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                #
                # The effect on memory consumption is not usually significant. No extra memory is allocated if this
                # module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
                # called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
                # then we can end up with multiple unsharded gradients allocated and queued for reduction.
                #
                # We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
                # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
                # reduction for this module, before scheduling additional reduction work. Then at most there are two
                # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
                param.grad = None
                callback_fn = functools.partial(self._post_reduction_hook, param)
                grad_chunks = chunk_and_pad(grad, self.process_group_reduce_scatter.size())
                self._reducer.reduce_scatter_async(
                    grad_chunks, group=self.process_group_reduce_scatter, callback_fn=callback_fn
                )
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                self._post_reduction_hook(param, param.grad.data)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])

    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        self.assert_state(TrainingState.BACKWARD_POST)
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = reduced_grad.data
            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())

        if param._is_sharded:
            # Accumulate into the gradient shard.
            if getattr(param, "_saved_grad_shard", None) is None:
                param._saved_grad_shard = reduced_grad.data
            else:
                assert (
                    param._saved_grad_shard.shape == reduced_grad.shape
                ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
                param._saved_grad_shard.data += reduced_grad.data
            reduced_grad = param._saved_grad_shard.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self.move_grads_to_cpu:
            param._cpu_grad.copy_(reduced_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            reduced_grad.data.record_stream(torch.cuda.current_stream())

    def _queue_wait_for_post_backward(self) -> None:
        """Try to queue a `wait_for_post_backward` callback.

        Only called on the root instance; only one callback is queued, at the
        beginning of the outer-most backward.
        """
        assert self._is_root
        if not self._post_backward_callback_queued:
            self.assert_state([TrainingState.IDLE])
            self._post_backward_callback_queued = True
            Variable._execution_engine.queue_callback(self._wait_for_post_backward)
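            # Note: queue_callback() schedules the given function to run once the
            # current autograd engine execution (i.e. the whole backward graph)
            # finishes, which lets _wait_for_post_backward flush pending reductions
            # and finalize gradients exactly once per backward pass.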

    @torch.no_grad()
    def _wait_for_post_backward(self) -> None:
        """Wait for post-backward to finish. Only called on root instance."""
        # Note: the backward runtime swallows assert errors, so we use p_assert() here.
        p_assert(self._is_root, "WFPB not called on root")
        # Check if the root module has params and if any of them has
        # the `requires_grad` field set. If `requires_grad=False` for
        # all the params, the post_backward hook will not fire and the
        # state will remain in `TrainingState.BACKWARD_PRE`.
        if any([p.requires_grad for p in self.params]):
            self.assert_state(TrainingState.BACKWARD_POST)
        else:
            self.assert_state(TrainingState.BACKWARD_PRE)

        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self._streams["post_backward"]):
                p_assert(self._reducer is not None, "WFPB: reducer is None")
                assert self._reducer is not None  # make mypy happy
                self._reducer.flush()
            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
            if self.move_grads_to_cpu:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()

        # A backward pass is done, clean up below.

        # Free reducer buffers.
        if self._reducer is not None:
            self._reducer.teardown()

        def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
            """Helper used below on all fsdp modules."""
            for p in fsdp_module.params:
                if not p.requires_grad:
                    continue
                if hasattr(p, "_shard_bwd_hook"):
                    p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                    p._shard_bwd_hook[1].remove()
                    delattr(p, "_shard_bwd_hook")

                # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
                # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
                # sync passes, if desired.
                if not self._require_backward_grad_sync:
                    continue

                # Parameter and gradient devices must match.
                if hasattr(p, "_cpu_grad"):
                    p_assert(p.device == torch.device("cpu"), f"WFPB: incorrect cpu_grad device {p.device}")
                    p.grad = p._cpu_grad
                elif hasattr(p, "_saved_grad_shard"):
                    p_assert(
                        p.device == p._saved_grad_shard.device,
                        f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                    )
                    p.grad = p._saved_grad_shard

                if hasattr(p, "_saved_grad_shard"):
                    delattr(p, "_saved_grad_shard")

        # Update root and nested FSDP's hooks and flags.
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                _finalize_parameters(m)
                m._free_ssd_offload()
                m._pre_backward_hook_has_run = False
                if any(p.requires_grad for p in m.parameters()):
                    # Check if the module has params and if any of them has
                    # the `requires_grad` field set. If `requires_grad=False` for
                    # all the params, the post_backward hook will not fire and the
                    # state will remain in `TrainingState.BACKWARD_PRE`.
                    if any([p.requires_grad for p in m.params]):
                        m.assert_state(TrainingState.BACKWARD_POST)
                    else:
                        m.assert_state(TrainingState.BACKWARD_PRE)
                else:
                    # When `m` and its children have no params, or have params but
                    # none with `requires_grad==True`, there are two cases:
                    # 1. output tensors are `requires_grad==True`. In this case,
                    # pre-backward hook is still registered, so it is in BACKWARD_PRE state.
                    # 2. output tensors are `requires_grad==False`. In this case,
                    # pre-backward hook is not registered, so it is in IDLE state.
                    m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
                m.training_state = TrainingState.IDLE

                if m._is_root:
                    # reset this flag for cases like "one forward pass + multiple backward passes"
                    self._post_backward_callback_queued = False
                    # clear this list for next iteration
                    p_assert(
                        self._output_pre_backward_hook_registered is not None,
                        "WFPB: self._output_pre_backward_hook_registered should not be None",
                    )
                    assert self._output_pre_backward_hook_registered is not None  # make mypy happy
                    self._output_pre_backward_hook_registered.clear()

    @torch.no_grad()
    def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Note, this is idempotent if full params are already gathered. Callers
        assume the idempotency. So please keep it that way.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``force_full_precision=False`` and the full params are already gathered.
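
        A simplified sketch of the gather for one sharded parameter (illustrative
        only; the real code below also handles dtype casts, CPU offload and
        dedicated CUDA streams)::

            output = p._full_param_padded   # numel == shard numel * world_size
            dist._all_gather_base(output, p.data, group=self.process_group)
            p.data = output[: p._orig_size.numel()].view(p._orig_size)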
        """
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p._is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p._is_sharded:
                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                    assert p._fp16_shard is not None
                    p.data = p._fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p._full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

        if self.ssd_offload:
            for p in self.params:
                assert isinstance(p, SsdFlatParameter)
                p.to_tensor()

            self.has_full_params = False

        if self._has_shared_params:
            # self.has_full_params flag can be out of sync if a shared param is
            # sharded by another FSDP instance. An example is the eval case, where
            # this instance has reshard_after_forward=False but the sharing instance
            # has reshard_after_forward=True. Then, on the second forward, the
            # other instance can shard the shared param, but this instance
            # can mistakenly think the full param is already gathered, based on the
            # has_full_params flag.
            #
            # Therefore, we update the flag accordingly here.
            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                self._cast_fp32_param_shards_to_fp16()

            if self.move_params_to_cpu:
                if force_full_precision:
                    # If the compute_dtype and storage dtype are the same,
                    # use pinned memory. Otherwise move p.data to the compute
                    # device.
                    if self.params[0].dtype == self.compute_dtype:
                        self._cast_fp32_param_shards_to_fp16()
                    else:
                        for p in self.params:
                            p.data = p.data.to(self.compute_device)

            for p in self.params:
                if not p._is_sharded:  # e.g., when world_size == 1
                    update_p_data()
                else:
                    # Skip if already built. Only shared params can be rebuilt multiple times.
                    # A corner case is p._orig_size = (1,), which means the shape equality is
                    # not a perfect check. But we assume we don't share a param with shape (1,).
                    if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
                        continue
                    # If self.move_params_to_cpu and force_full_precision, we need to move
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                    p_size = p._full_param_padded.size()
                    assert p_size.numel() % self.world_size == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate fresh tensor in full precision since we are in
                        # mixed precision and full precision rebuild is asked.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p._full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage_(p._full_param_padded, size=p_size)
                        output_tensor = p._full_param_padded

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                        dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                    else:
                        chunks = list(output_tensor.chunk(self.world_size))
                        dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                        self._free_fp16_param_shard([p])

                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
                        self._free_fp16_param_shard([p])

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def _use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.

        Note: this might be called while full_params is already in use. So please
              make sure it is idempotent in that case.
        """
        assert self.has_full_params
        for p in self.params:
            if not p._is_sharded:
                if self.mixed_precision or self.move_params_to_cpu:
                    assert p._fp16_shard is not None
                    assert p._fp16_shard.storage().size() != 0
                    p.data = p._fp16_shard
            else:
                assert p._full_param_padded.storage().size() != 0, f"{p._orig_size} {id(self)}"
                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)

    @torch.no_grad()
    def _prep_grads_for_backward(self) -> None:
        """Make sure p.grad is correctly prepared for the backward with
        the right shape, device, accumulated values, etc.
        """
        for p in self.params:
            if p.grad is not None:
                if p.grad.device != p.data.device:
                    p.grad = None
                elif p.grad.size() == p._orig_size:
                    # This is gradient accumulation with no_sync context.
                    pass
                elif p.grad.size() == p._fp32_shard.shape:
                    # This is gradient accumulation without no_sync context.
                    # We save the grad shard and set p.grad to None for this backward pass.
                    # We will accumulate after this pass's grad is generated and reduced and
                    # sharded.
                    p._saved_grad_shard = p.grad.data
                    p.grad = None
                else:
                    raise AssertionError(f"unexpected grad shape: {p.grad.size()}")

    @torch.no_grad()
    def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        current_stream = torch.cuda.current_stream()
        for p in params:
            if not p._is_sharded:  # e.g., world_size == 1
                if self.mixed_precision or self.move_params_to_cpu:
                    self._free_fp16_param_shard([p])
                continue
            # Don't let PyTorch reuse this memory until all work in the current
            # stream is complete.
            p._full_param_padded.record_stream(current_stream)
            # There may be external references to the Tensor Storage that we
            # can't modify, such as references that are created by
            # ctx.save_for_backward in the forward pass. Thus when we
            # unshard parameters, we should reuse the original Tensor
            # Storage object and unshard it in-place. For now, just resize
            # the Storage to 0 to save memory.
            free_storage_(p._full_param_padded)

    def local_metadata_dict(self) -> Dict[str, Any]:
        """
        Get the information needed to reconstruct the model from shards offline.

        See the `consolidate_shard_weights` method below.
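
        The returned structure looks roughly like this (an illustrative sketch;
        the actual keys and values depend on the wrapped model)::

            {
                "param_metadata": [
                    {
                        "fsdp_path": "...",
                        "params": {
                            "<backing_param_name>": {
                                "names": [...], "shapes": [...],
                                "numels": [...], "padding": 0,
                            },
                        },
                        "no_broadcast_optim_state": False,
                        "shared_param_info": [...],
                    },
                ],
                "buffer_names": [...],
            }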
        """
        param_metadata = []
        for path, m in self.named_modules():
            if isinstance(m, FullyShardedDataParallel):
                metadata: Dict[str, Any] = {}
                metadata["fsdp_path"] = _clean_path(path)
                metadata["params"] = {}

                metadata["no_broadcast_optim_state"] = m.no_broadcast_optim_state
                shared_param_info = []
                for (mpath_dst, mpath_src, _, src_name, _, dst_name) in m._shared_param_infos:
                    src_param_path = _clean_path(mpath_src + "." + src_name if mpath_src else src_name)
                    dst_param_path = _clean_path(mpath_dst + "." + dst_name if mpath_dst else dst_name)
                    shared_param_info.append((src_param_path, dst_param_path))
                metadata["shared_param_info"] = shared_param_info

                for i, p in enumerate(m.params):
                    if i < m._num_flatten_params:
                        backing_param_name = m.module.flat_param_names[i]
                        names, shapes, numels = m.module.metadata(i)
                    else:
                        assert len(m._param_name_groups[i]) == 1
                        backing_param_name = m._param_name_groups[i][0]
                        names = [backing_param_name]
                        shapes = [p._orig_size]
                        numels = [p._orig_size.numel()]
                    backing_param_name = _clean_path(backing_param_name)
                    metadata["params"][backing_param_name] = {
                        "names": [_clean_path(n) for n in names],  # A list of str.
                        "shapes": shapes,  # A list of torch.Size.
                        "numels": numels,  # A list of int.
                        "padding": m.numel_padded_per_param[i],  # An int for padding added to the backing parameter.
                    }
                param_metadata.append(metadata)

        buffer_names = [_clean_path(buffer_name) for buffer_name, _ in self.named_buffers(recurse=True)]
        return dict(param_metadata=param_metadata, buffer_names=buffer_names)

    @staticmethod
    def consolidate_shard_weights(
        shard_weights: List[Dict[str, torch.Tensor]],
        shard_metadata: List[Dict[str, Any]],
        with_module_buffers: bool = True,
        strict: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Given a list of weights and metadata associated with N shards, reconstruct
        the weights of an equivalent consolidated (non-sharded) state dict.

        Module parameters are consolidated using the shard metadata.

        Module buffers are taken from shard 0: this assumes that module buffers
        are either synchronized or that the shard 0 value is valid for all shards.
        If this behavior is not correct for your module (for instance if buffers
        need to be all-reduced instead), you can disable it with `with_module_buffers=False`.

        This method is used to re-assemble checkpoints of shards without
        having to instantiate FSDP wrappers with the world size (i.e. large
        number of GPUs) originally used to save the shards.

        Args:
            shard_weights (List[Dict[str, torch.Tensor]]):
                List of dictionaries that contains sharded weights from
                each rank.
            shard_metadata (List[Dict[str, Any]]):
                List of dictionaries that contains metadata from each shard.
                See `local_metadata_dict` above.
            with_module_buffers (bool):
                If shard 0's buffer should be returned in the consolidated
                weight dict.
                Default: True.
            strict (bool):
                Whether to require complete shard weights. If True, every key in the metadata must be
                present in the weights; if False, incomplete shard weights are allowed and missing keys are skipped.

        """
        if len(shard_weights) != len(shard_metadata) or not len(shard_weights):
            raise ValueError("Require metadata for each shard and non-empty shards")

        consolidated_weights = {}
        original_world_size = len(shard_weights)

        # For every FSDP instance.
        for fsdp_obj_idx, metadata in enumerate(shard_metadata[0]["param_metadata"]):
            fsdp_path = metadata["fsdp_path"]
            params = metadata["params"]
            # For every this-FSDP-owned param, flattened or not.
            for backing_param_name, v in params.items():
                in_state_dict_key = ".".join([fsdp_path, backing_param_name]) if fsdp_path else backing_param_name
                # Get full param back with pad removed.
                if in_state_dict_key not in shard_weights[0] and (not strict):
                    continue
                shards = []
                for rank in range(original_world_size):
                    shard = shard_weights[rank][in_state_dict_key]
                    pad = shard_metadata[rank]["param_metadata"][fsdp_obj_idx]["params"][backing_param_name]["padding"]
                    shards.append(_unpad(shard, pad))
                    if metadata["no_broadcast_optim_state"]:
                        break
                full_param = torch.cat(shards, dim=0)
                # (Potentially), split the full param and create original params.
                names, shapes, numels, _ = v.values()
                assert sum(numels) == full_param.size(0)
                for n, t, s in zip(names, full_param.split(numels), shapes):
                    out_state_dict_key = ".".join([fsdp_path, n]) if fsdp_path else n
                    consolidated_weights[out_state_dict_key] = t.view(s)

            # Copy shared parameters for this FSDP instance.
            for src_path, dest_path in metadata["shared_param_info"]:
                consolidated_weights[dest_path] = consolidated_weights[src_path]

        # Deal with the buffers, which are not parameters and are not sharded by FSDP
        # and therefore are replicated among the different shards.
        # We take the values of the first shard (this assumes that there is some form
        # of synchronization between shards or that all shards' buffers are equivalent).
        if with_module_buffers:
            for buffer_name in shard_metadata[0]["buffer_names"]:
                if buffer_name not in shard_weights[0] and (not strict):
                    continue
                consolidated_weights[buffer_name] = shard_weights[0][buffer_name]

        return consolidated_weights
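
    # A hedged offline-consolidation sketch, assuming each rank previously saved a
    # checkpoint shaped like {"weights": ..., "metadata": ...} (see the sketch under
    # ``local_metadata_dict`` above); ``world_size``, ``unwrapped_model`` and the
    # file names are placeholders.
    #
    #     ckpts = [torch.load(f"shard_rank{r}.pt", map_location="cpu") for r in range(world_size)]
    #     full_sd = FullyShardedDataParallel.consolidate_shard_weights(
    #         shard_weights=[c["weights"] for c in ckpts],
    #         shard_metadata=[c["metadata"] for c in ckpts],
    #     )
    #     unwrapped_model.load_state_dict(full_sd)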

    @torch.no_grad()
    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p._fp32_shard

    @torch.no_grad()
    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p._fp16_shard is not None
                alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
                p._fp16_shard.copy_(
                    # If move_params_to_cpu is True, this will be non-blocking
                    # because _fp32_shard is pinned, otherwise it's a no-op.
                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
                )
                p.data = p._fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p._fp16_shard is not None:
                # _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
                # free it until the work in the current stream completes.
                p._fp16_shard.record_stream(current_stream)
                free_storage_(p._fp16_shard)

    def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
        """Assert we are in the given state."""
        # Since assert can be turned off and this error checking
        # is really important, we use explicit error checking
        # and raise a ValueError if needed.
        if isinstance(state, TrainingState):
            state = [state]
        if self.training_state not in state:
            msg = f"expected to be in states {state} but current state " f"is {self.training_state}"
            # In case we are failing in the context of autograd hook, asserting
            # may not generate useful msg. So, let's print it to be sure.
            if self.rank == 0:
                print(f"Asserting FSDP instance is: {self}")
                print(f"ERROR: {msg}")
                traceback.print_stack()
            raise ValueError(msg)

    def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
        """Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
        world_pad_info: List[List[List[int]]] = []  # this will contain values from the whole world.
        my_pad_info: List[List[int]] = [cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances]
        for rank in range(self.world_size):
            if rank == self.rank:
                pad_info = my_pad_info
            else:
                pad_info = [[0]] * len(my_pad_info)
            dist.broadcast_object_list(pad_info, src=rank, group=self.process_group)
            if self.rank == 0:
                world_pad_info.append(pad_info)
        return world_pad_info

    def _gather_optim_state(
        self, sd_state: Dict[int, Dict[str, Any]]
    ) -> Tuple[Dict[int, Dict[str, List]], Dict[int, Dict[str, List]]]:
        """For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
        gathered_state: Dict[int, Dict[str, List[Any]]] = {}
        singleton_state: Dict[int, Dict[str, List[Any]]] = {}  # Dimensionless tensor
        for k, v in sd_state.items():
            gathered_state[k] = {}
            singleton_state[k] = {}
            # For shared params, we are not flattening. We have only 1 non-shared
            # param that has the optimizer state. So we handle it with the correct
            # parameter list.
            non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
            assert (
                len(non_shared_params) == 1
            ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
            desired_buffer_size = non_shared_params[0]._full_param_padded.size()
            buffer = None  # for sharded tensors
            singleton_buffer = None  # for singleton tensors
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    t = t.to(self.compute_device)

                if ou.is_singleton_tensor(t):
                    if singleton_buffer is None:
                        singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
                    dist.all_gather(singleton_buffer, t, group=self.process_group)
                    if self.rank == 0:
                        singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
                        assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
                elif torch.is_tensor(t):
                    if buffer is None:
                        buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
                    dist.all_gather(buffer, t, group=self.process_group)
                    if self.rank == 0:
                        gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
                elif self.rank == 0:  # Add non tensor state
                    gathered_state[k][buffer_name] = [t]

        return gathered_state, singleton_state

    def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
        """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
        sharded properties are not exposed. Multiple parameter groups are not yet supported.

        This should be called only on the root FSDP instance.
        Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.

        Args:
            optim (Optimizer): an optimizer instance for this FSDP rank. Its state_dict is
                        used in the consolidation. However, its state is not modified.

        Returns:

            * A dict with four entries (on rank 0; other workers return ``None``)
                * state - a dict holding gathered optimization state, 1 entry per unflat parameter
                * param_groups - a dict containing the 1 parameter group
                * param_id_map - global (unflat) to local (flat) id mapping
                * uncollected_local_ids - keys in the state dict that were not broadcast

        """
        if not self.flatten_parameters:
            raise NotImplementedError("optim state dict requires flatten_parameters=True")

        self._lazy_init()
        sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
        assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
        assert len(sd["param_groups"]) == 1, "Param groups are not supported"
        # We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
        state, singleton_state = self._gather_optim_state(sd.pop("state"))
        pad_info = self._broadcast_pad_info_to_r0()
        if self.rank != 0:
            return None
        # Unify the shard states by concatenating tensors and unflattening params
        new_state_dict = ou.build_unflat_state_dict(
            self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
        )
        self.uncollected_opt_state = {}
        assert "uncollected_local_ids" in new_state_dict
        return new_state_dict
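
    # A hedged usage sketch: call this on the root FSDP instance on every rank
    # (it participates in collectives), but only rank 0 receives the consolidated
    # dict; the other ranks get ``None``. ``fsdp_model``, ``optim`` and the file
    # name are placeholders.
    #
    #     full_osd = fsdp_model.gather_full_optim_state_dict(optim)
    #     if fsdp_model.rank == 0:
    #         torch.save(full_osd, "optim_full.pt")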

    @property
    def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
        """Returns all fsdp modules in self.modules() including self."""
        return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

    def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
        new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
        if self.rank == 0:
            # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
            self.uncollected_opt_state = {
                k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
                for k, v in osd["state"].items()
                if k in uncollected_ids
            }

        pg = copy.deepcopy(osd["param_groups"])
        new_dct["param_groups"] = pg
        return new_dct

    def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Get the portion of the optimizer state dict associated with the shard

        This can be used to get the right sharded optimizer state to be loaded
        into the sharded optimizer for this FSDP rank.

        Args:
            full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state_dict``, or loaded from a checkpoint.

        Returns:
            (dict): a shard of the optimizer state.
        """
        # Assert nesting is the same as it was at save time
        instance_list = self._fsdp_instances
        ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
        ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
        if self.flatten_parameters:
            full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
            assert len(full_optim_state_dict["state"]) in (
                0,
                len(instance_list),
            ), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

        # get the portion of dict associated with the shard, in place
        for id, s in full_optim_state_dict["state"].items():
            for k, v in s.items():
                if torch.is_tensor(v) and id not in ids_not_to_shard:
                    v_shard, _ = self._get_shard(v)
                elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
                    # if we are resuming on larger world size, take first entry
                    v_shard = v[0] if self.rank >= len(v) else v[self.rank]
                    assert ou.is_singleton_tensor(v_shard)
                else:
                    v_shard = v  # don't shard entries that are not tensors
                full_optim_state_dict["state"][id][k] = v_shard

        return full_optim_state_dict
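
    # A hedged restore sketch: load the consolidated optimizer state produced by
    # ``gather_full_optim_state_dict`` and carve out this rank's shard before
    # handing it to the local optimizer. ``fsdp_model``, ``optim`` and the file
    # name are placeholders.
    #
    #     full_osd = torch.load("optim_full.pt", map_location="cpu")
    #     shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
    #     optim.load_state_dict(shard_osd)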

    def _print_r0(self, msg: str, restart: bool = False) -> None:
        """Debugging utility to print memory usage stats nicely on rank 0"""
        if restart:
            self._tstart = time.time()
        if self.rank == 0:
            gb_denom = 1024 ** 3
            logging.info(
                f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
            )

    # Note: This property will be deprecated in an upcoming release in favor of `move_params_to_cpu`.
    @property
    def cpu_offload(self) -> bool:
        return self.move_params_to_cpu


def p_assert(cond: Any, s: Any) -> None:
    """Used in backward context to make sure error is printed."""
    if not cond:
        print(s)
        raise AssertionError


def _get_default_cuda_device(module: nn.Module) -> torch.device:
    """Try to infer CUDA device from module parameters."""
    try:
        compute_device = next(module.parameters()).device
        if compute_device.type == "cuda":
            return compute_device
    except StopIteration:
        pass
    # Fall back to current CUDA device
    return torch.device("cuda")


def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
    We also retain the requires_grad flag so that casting doesn't affect the autograd graph.
    """

    def fn_fp16(x: torch.Tensor) -> torch.Tensor:
        if x.dtype is torch.float32:
            y = x.half()
            if x.is_leaf:
                y.requires_grad = x.requires_grad
            return y
        return x

    def fn_fp32(x: torch.Tensor) -> torch.Tensor:
        if x.dtype is torch.float16:
            y = x.float()
            if x.is_leaf:
                y.requires_grad = x.requires_grad
            return y
        return x

    fn = fn_fp16 if to_fp16 else fn_fp32
    context = torch.no_grad() if no_grad else contextlib.suppress()
    with context:  # type: ignore
        return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
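
# A hedged example of what the helper above does: float32 tensors nested inside
# args/kwargs are cast to float16 (or float16 back to float32 when ``to_fp16`` is
# False), while non-floating-point and non-tensor values pass through unchanged.
# The sample inputs below are arbitrary.
#
#     args = (torch.randn(2, 2), {"x": torch.randn(3)}, "label")
#     fp16_args, _ = cast_floats_to_right_precision(True, True, *args)
#     assert fp16_args[0].dtype is torch.float16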


def free_storage_(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


@torch.no_grad()
def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
    """Allocate storage for a tensor."""
    if data.storage().size() == size.numel():  # no need to reallocate
        return
    assert data.storage().size() == 0
    data.storage().resize_(size.numel())


def _post_state_dict_hook(
    state_dict_on_rank_0_only: bool,
    module: FullyShardedDataParallel,
    state_dict: "OrderedDict[str, torch.Tensor]",
    prefix: str,
    *args: Any,
) -> "OrderedDict[str, torch.Tensor]":
    # When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` only
    # returns the full state dict on rank 0 and returns an empty dict on all
    # other ranks, which allows FullyShardedDataParallel to skip the GPU -> CPU
    # copy on non-rank-0 workers altogether and prevents OOM.
    if state_dict_on_rank_0_only and dist.get_rank() != 0:
        state_dict.clear()
        return state_dict
    # Assuming we are in a ``summon_full_params()`` context, we need to clone
    # each tensor so that it does not get freed (in-place) when the context
    # exits. At the same time, this hook can be called multiple times
    # recursively, so we need to make sure that we only clone each tensor at
    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
    # which keeps track of tensors that are no longer at risk of being freed.
    for key in state_dict.keys():
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
            continue
        if state_dict[key].device.type != module.state_dict_device.type:
            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
            state_dict[key]._has_been_cloned = True
        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
            # We copy the state_dict since full param will be freed after we
            # exit the ``summon_full_params()`` context.
            state_dict[key] = state_dict[key].clone()
            state_dict[key]._has_been_cloned = True

    # Remove "_fsdp_wrapped_module." prefix
    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
    return state_dict


def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")


def _clean_path(path: str) -> str:
    """Remove FSDP related wrapper modules from a given state dict key str path."""
    return ".".join([split for split in path.split(".") if split not in {"_fsdp_wrapped_module", "_fpw_module"}])


def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
    if pad > 0:
        shard = shard[:-pad]
    return shard


########################################################################################
# Below are APIs used together with FSDP, but not directly part of FSDP.
########################################################################################


def auto_wrap_bn(
    module: nn.Module,
    single_rank_pg: bool = False,
    process_group: Optional[ProcessGroup] = None,
    fsdp_config: Optional[Dict[str, Any]] = None,
    wrap_it: bool = True,
    assert_on_collision: bool = True,
) -> nn.Module:
    """
    Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
    to sync BN is used and the outer FSDP is flattening.

    We put BN in its own full-precision, unflattened, single-GPU-group FSDP. Note that SyncBNs still have
    a group size == world_size. The input and output of BN are still FP16 in mixed precision mode.
    See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html

    This needs to be done at each rank, like models being wrapped by FSDP at each rank.

    Args:
        module (nn.Module):
            The model (or part of the model) in which BNs are to be pre-wrapped.
        single_rank_pg (bool):
            If true, put BNs in a single-rank process group. Default False.
            This might be needed for Apex sync BN support. Still under construction.
        process_group (ProcessGroup):
            Optional process group to be used.
        fsdp_config (Dict):
            Optional fsdp_config to be used.
        wrap_it (bool):
            Whether or not wrap the module after setting the config.
            Default: True
        assert_on_collision (bool):
            Whether or not assert if a wrapper_config already exists on the module.
            Default: True

    Returns:
        Processed module, where BNs are wrapped with a special FSDP instance.
    """
    # Prepare a fsdp_config dict for BNs.
    pg = process_group
    if single_rank_pg:
        # No sharding with this single member group.
        my_rank = dist.get_rank()
        pg = get_process_group_cached(ranks=[my_rank])

    if fsdp_config is None:
        fsdp_config = {
            "process_group": pg,
            "mixed_precision": False,  # Keep the weights in FP32.
            "flatten_parameters": False,  # Do not flatten.
            # Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
            # **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
            # within the checkpoint's outer backward when multiple forward passes are used.
            "reshard_after_forward": False,
            # No bucketing or small bucketing should be enough for BNs.
            "bucket_cap_mb": 0,
            # Setting this for SyncBatchNorm. This may have a performance impact. If
            # SyncBatchNorm is used, this can be enabled by passing in the `fsdp_config` argument.
            "force_input_to_fp32": False,
        }

    # Assign the config dict to BNs.
    for m in module.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            if assert_on_collision:
                assert not hasattr(
                    m, "wrapper_config"
                ), "Module shouldn't already have a wrapper_config. Is it tagged already by another policy?"
            m.wrapper_config = fsdp_config

    # Wrap it.
    with (
        enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
    ):
        return auto_wrap(module)
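

# A hedged usage sketch for ``auto_wrap_bn``: pre-wrap the BN layers before wrapping
# the whole model, so BN stays in FP32 and unflattened even when the outer FSDP uses
# mixed precision and flattening. ``my_module`` is a placeholder.
#
#     my_module = auto_wrap_bn(my_module)
#     model = FullyShardedDataParallel(my_module, mixed_precision=True, flatten_parameters=True)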