# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
from enum import Enum, auto
import functools
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union

import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.optim.utils import calc_grad_norm
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401


class TrainingState(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
    to make sure APIs are called in the correct state.

    TODO (Min): It would be nice to capture the stepping state as well.
        Maybe we can use the model.zero_grad() call, but not sure if it
        is called if optim.zero_grad() is used instead.
        It would be nice to have clear state transitions be explicit, e.g.:

        zero_grad -> fwd -> bwd -> optionally accum grad by repeating
        fwd/bwd -> stepping -> loop back to zero_grad
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
    is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.

    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
    .. _DeepSpeed: https://www.deepspeed.ai/

    Usage::

        torch.cuda.set_device(device_id)
        sharded_module = FullyShardedDataParallel(my_module)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        x = sharded_module(x, y=3, z=torch.Tensor([1]))
        loss = x.sum()
        loss.backward()
        optim.step()

    It is also possible to shard individual layers separately and have an outer
    wrapper handle any leftover parameters. This can be helpful to further
    reduce GPU memory usage, reduce system memory usage when initializing large
    models, and improve training speed by overlapping the all-gather step
    across the forward pass. For example::

        sharded_model = FullyShardedDataParallel(
            nn.Sequential(  # doesn't have to be nn.Sequential
                nn.Linear(5, 100),
                FullyShardedDataParallel(nn.Linear(100, 100)),
                FullyShardedDataParallel(nn.Linear(100, 100)),
                nn.Linear(100, 5),
            )
        )

    .. warning::

        The optimizer must be initialized *after* the module has been wrapped,
        since FSDP will shard parameters in-place and this will break any
        previously initialized optimizers.

    .. warning::

        If you wrap every parameter inside a nested FSDP and leave the outer
        FSDP without any parameters of its own, activation checkpointing may
        trigger an assert during the backward pass. The solution is to leave
        some parameters to the outer FSDP instance.

    Args:
        module (nn.Module):
            module to be wrapped with FSDP
        process_group (Optional):
            process group for sharding
        reshard_after_forward (bool, Optional):
            if ``True``, reshard parameters after the forward pass. This saves
            memory but slows training. This is only relevant when resharding
            individual layers.
        mixed_precision (bool, Optional):
            if ``True``, inputs, activations and gradients will be kept in FP16;
            computation and communication will occur in FP16; and a (sharded)
            master copy of the model weights will be maintained in FP32.
        fp32_reduce_scatter (bool, Optional):
            if ``True``, then reduce-scatter gradients in FP32. This is only
            relevant when *``mixed_precision``* is ``True``.
        flatten_parameters (bool, Optional):
            if ``True``, flatten parameters into a single contiguous tensor,
            which improves training speed.
        cpu_offload (bool, Optional):
            if ``True``, offload FP32 params to CPU. This is only relevant when
            *``mixed_precision``* is ``True``.
        compute_dtype (torch.dtype, Optional):
            dtype for full parameters for computation. This defaults to
            ``torch.float32`` unless *``mixed_precision``* is set, in which case
            it defaults to ``torch.float16``.
        buffer_dtype (torch.dtype, Optional):
            dtype for buffers for computation. This defaults to ``compute_dtype``.
        move_grads_to_cpu (bool, Optional):
            move gradient shard to CPU after reduction. This is useful when
            combined with CPU-based optimizers. It defaults to the value of
            *``cpu_offload``*.
        bucket_cap_mb (int, Optional):
            FSDP will bucket parameters so that gradient reduction can
            potentially overlap with backward computation. ``bucket_cap_mb``
            controls the bucket size in megabytes (MB). Buckets are sub-divided
            based on world_size, so the max shard size is roughly
            ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
            Default: 25.
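            For example, with ``bucket_cap_mb=25`` and a world size of 8, each
            shard bucket is roughly 3 MB.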
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        reshard_after_forward: bool = True,
        mixed_precision: bool = False,
        fp32_reduce_scatter: bool = False,
        flatten_parameters: bool = True,
        cpu_offload: bool = False,
        compute_dtype: Optional[torch.dtype] = None,
        buffer_dtype: Optional[torch.dtype] = None,
        move_grads_to_cpu: Optional[bool] = None,
        bucket_cap_mb: int = 25,
    ):
        super().__init__()
        self.process_group = process_group or dist.new_group()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.flatten_parameters = flatten_parameters
        self.cpu_offload = cpu_offload
        self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb

        if self.fp32_reduce_scatter and not self.mixed_precision:
            raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
        if self.cpu_offload and not self.mixed_precision:
            raise ValueError("cpu_offload requires mixed_precision=True")

        compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
        validate_process_group(compute_device, self.process_group)

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))

        self._has_params = len(params) > 0
        if not self._has_params:
            self.flatten_parameters = False

        if self.flatten_parameters:
            self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
            del module  # free original module in case it helps garbage collection
            self.params = [self._fsdp_wrapped_module.flat_param]
        else:
            self._fsdp_wrapped_module = module
            self.params = params

        # Shard module parameters in place
        self._shard_parameters_()

        # Make sure all parameters are sharded.
        for n, p in self.named_parameters():
            assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"

        self._reset_lazy_init()

        # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
        self._require_backward_grad_sync: bool = True

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState.IDLE

        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
        # prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        self._return_full_state_dict = True

    @property
    def module(self) -> nn.Module:
        return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance

    @torch.no_grad()
    def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        """Move all buffers to the specified device and dtype, recursively."""
        cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
        self.apply(cast_fn)

    @property
    def params_with_grad(self) -> List[Parameter]:
        """[p for p in self.parameters() if p.grad is not None] """
        return [p for p in self.parameters() if p.grad is not None]

    @torch.no_grad()
    def clip_grad_norm_(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        # filter_params_fn: Callable[[Any], Any] = None,
    ) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all
        gradients together, as if they were concatenated into a single vector.
        Gradients are modified in-place.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
                for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
            handles the partitioning and multiple devices per rank under the
            hood. The default torch util is not applicable here, because each
            rank only has a partial view of all the grads in the model, so
            calling it in the OSS context would lead to different scaling being
            applied per subset of model parameters.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
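
        Example (illustrative sketch; assumes ``sharded_module``, ``loss`` and
        ``optim`` come from a standard training loop such as the one in the
        class docstring)::

            loss.backward()
            sharded_module.clip_grad_norm_(max_norm=1.0)
            optim.step()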
        """
        # We don't call torch.cuda.synchronize() here, since clipping can be
        # inside the train loop and we probably don't want to force a GPU-CPU sync.
        # _lazy_init should be sufficient, since it will force the other streams
        # to sync with the default stream (via _wait_for_previous_optim_step).
        self._lazy_init()
        assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
        self.assert_state(TrainingState.IDLE)

        max_norm = float(max_norm)
        norm_type = float(norm_type)
        params_with_grad = self.params_with_grad
        if not self.children_share_process_group:
            raise NotImplementedError(
                "clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
            )
        # Computes the max norm for this shard's gradients and sync's across workers
        local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
        if norm_type == inf:
            total_norm = local_norm
            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
        else:
            total_norm = local_norm ** norm_type
            dist.all_reduce(total_norm, group=self.process_group)
            total_norm = total_norm ** (1.0 / norm_type)

        if self.move_grads_to_cpu:
            total_norm = total_norm.cpu()
        # Now multiply each grad by (max_norm / total_norm), same as in torch 1.7 (https://tinyurl.com/3wtxhhqq).
        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
        if clip_coef < 1:
            # multiply by clip_coef
            for p in params_with_grad:
                p.grad.detach().mul_(clip_coef.to(p.grad.device))  # type: ignore

        return total_norm

    @torch.no_grad()
    def _shard_parameters_(self) -> None:
        """
        At initialization we wrap a module with full parameters and shard the
        parameters in-place. Sharding is implemented by viewing each parameter
        as a 1D Tensor and retaining only a single slice, where the slice size
        is determined by the number of data parallel workers.

        Wrapping modules with many small parameters (or with a very large data
        parallel world size) will result in many small parameter shards and slow
        performance. In this case it's better to set *``flatten_parameters``* to
        ``True``, so that all of the small parameters in the module are combined
        into a single contiguous Tensor and sharded once.

        After this initial sharding is complete, the user can initialize a
        ``torch.optim.Optimizer`` in the usual way, i.e.:

        .. code-block:: python

            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

        The optimizer will see only a single slice of parameters and will thus
        allocate less memory for optimizer state, avoiding redundancy across
        data parallel workers.
        """
        for p in self.params:
            assert not hasattr(p, "_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p._is_sharded = self.world_size > 1
            p._orig_size = p.data.size()

            if not p._is_sharded:
                continue
            p._is_sharded = True

            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data = self._get_shard(p.data)
            free_storage_(orig_data)

    def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the local shard of a given full tensor."""
        # Shard using torch.chunk to match all-gather/reduce-scatter.
        chunks = list(torch.flatten(tensor).chunk(self.world_size))
        while len(chunks) < self.world_size:
            chunks.append(chunks[0].new_empty(0))

        # Determine number of padding elements.
        num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
        assert num_to_pad >= 0, num_to_pad

        shard = chunks[self.rank].clone()
        if num_to_pad > 0:
            shard = F.pad(shard, [0, num_to_pad])
        return shard

    def extra_repr(self) -> str:
        return (
            f"rank={self.rank}, world_size={self.world_size}, "
            f"reshard_after_forward={self.reshard_after_forward}, "
            f"mixed_precision={self.mixed_precision}, "
            f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
            f"flatten_parameters={self.flatten_parameters}, "
            f"cpu_offload={self.cpu_offload}, "
            f"compute_dtype={self.compute_dtype}, "
            f"move_grads_to_cpu={self.move_grads_to_cpu}"
        )

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def __getstate__(self) -> Dict[str, str]:
        """Serialize the state of the current FullyShardedDataParallel instance.

        Some properties are not serializable (e.g., process groups, streams), so
        we remove them and try to reconstruct them in :func:`__setstate__`.
        """
        state = copy.copy(self.__dict__)
        state["is_sharded"] = [p._is_sharded for p in self.params]
        state["orig_sizes"] = [p._orig_size for p in self.params]
        if state["process_group"] is not None:
            state["process_group"] = "MISSING"  # process_group isn't pickleable
        self._reset_lazy_init()
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Intercept state setting and perform needed changes on params."""
        super().__setstate__(state)

        def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
            assert isinstance(p, Parameter)
            p.data = p.data.clone()  # move tensors out of shared memory
            p._is_sharded = is_sharded
            p._orig_size = size
            return p

        self.params = [
            fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
        ]
        del self.is_sharded
        del self.orig_sizes
        self._reset_lazy_init()

    # TODO (Min): figuring out how to do typing for this overloaded function.
    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]":  # type: ignore
        """
        Returns the whole (unsharded) state of the module. Parameters are not
        sharded, so the resulting state_dict can be loaded directly by the
        wrapped Module without any sharding-specific logic. Returned tensors
        will be full precision (e.g., FP32).

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
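
        Example (illustrative sketch; ``sharded_module`` and ``rank`` are
        assumed to be provided by the surrounding training script)::

            full_state = sharded_module.state_dict()  # must run on all ranks
            if rank == 0:
                torch.save(full_state, "checkpoint.pt")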
        """
        torch.cuda.synchronize()
        self._lazy_init()
        if self.mixed_precision:
            # Cast buffers to FP32 so that their dtype stays consistent with
            # the (full precision) parameters being returned.
            self._all_buffers_to(dtype=torch.float32)

        if self._return_full_state_dict:
            if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
                with self.summon_full_params(volatile=True):
                    state_dict = super().state_dict(*args, **kwargs)
            else:
                state_dict = super().state_dict(*args, **kwargs)
        else:
            if self.flatten_parameters:
                assert isinstance(self.module, FlattenParamsWrapper)
                state_dict = self.module.flat_state_dict(*args, **kwargs)
            else:
                state_dict = super().state_dict(*args, **kwargs)

        if self.cpu_offload:
            for k in state_dict.keys():
                state_dict[k] = state_dict[k].cpu()

        if self.mixed_precision:
            # Restore buffers back to ``buffer_dtype`` (typically FP16).
            self._all_buffers_to(dtype=self.buffer_dtype)
        return state_dict

    # TODO (Min): figuring out how to do typing for this overloaded function.
    def local_state_dict(self, *args, **kwargs):  # type: ignore
        """
        Returns the local (sharded) state of the module. Parameters are sharded,
        so the resulting state_dict can only be loaded after the Module has been
        wrapped with FullyShardedDataParallel.
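
        Example (illustrative sketch; each rank saves and later reloads its own
        shard, with ``rank`` supplied by the caller)::

            local_state = sharded_module.local_state_dict()
            torch.save(local_state, f"checkpoint_rank{rank}.pt")
            # ... later, after re-wrapping the module with FSDP ...
            sharded_module.load_local_state_dict(torch.load(f"checkpoint_rank{rank}.pt"))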
        """
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            return self.state_dict(*args, **kwargs)

    @contextlib.contextmanager
    def _no_return_full_state_dict(self) -> Generator:
        backup = self._return_full_state_dict
        self._return_full_state_dict = False
        try:
            yield
        finally:
            self._return_full_state_dict = backup

    def load_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """
        Load a whole (unsharded) state_dict.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
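
        Example (illustrative sketch; the checkpoint is assumed to have been
        saved via :func:`state_dict` from an identically wrapped module)::

            sharded_module.load_state_dict(torch.load("checkpoint.pt"))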
        """
        if self._return_full_state_dict:
            with self.summon_full_params():
                return self.module.load_state_dict(state_dict, strict)
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            return self.module.load_state_dict(state_dict, strict)

    def load_local_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
        """Load a local (sharded) state_dict."""
        with contextlib.ExitStack() as stack:
            # Tell any nested FSDP instances not to auto summon full params.
            for module in self.modules():  # includes self
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module._no_return_full_state_dict())
            output = self.load_state_dict(state_dict, strict)
        return output

    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """
        A context manager to disable gradient synchronization across
        data-parallel processes. Within this context, gradients will be
        accumulated on module variables, which will later be synchronized in
        the first forward-backward pass after exiting the context.

        .. note:: This may result in higher memory usage because we will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.
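
        Example (illustrative sketch of gradient accumulation; ``sharded_module``,
        ``batches`` and ``optim`` are assumed to be defined by the caller)::

            with sharded_module.no_sync():
                for batch in batches[:-1]:
                    sharded_module(batch).sum().backward()  # no gradient reduction
            sharded_module(batches[-1]).sum().backward()  # gradients are reduced here
            optim.step()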
        """
        self._lazy_init()
        assert self._is_root, "no_sync on inner FSDP is not supported"
        self.assert_state(TrainingState.IDLE)
        # This instance may wrap other FullyShardedDataParallel instances and we
        # need to set all of them to accumulate gradients.
        old_flags = []
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                old_flags.append((m, m._require_backward_grad_sync))
                m._require_backward_grad_sync = False
        try:
            yield
        finally:
            for m, old_flag in old_flags:
                m._require_backward_grad_sync = old_flag

    @contextlib.contextmanager
    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
        """
        A context manager to expose full params for the current FSDP instance.
        Can be useful *after* forward/backward for a model to get the params for
        additional processing or checking. Parameters will be gathered in full
        precision (e.g., FP32).

        .. note:: This can be used on inner FSDPs.

        .. note:: This can *not* be used within a forward or backward pass. Nor
            can forward and backward be started from within this context.

        .. note:: The full parameters will be freed after the context manager
            exits; it is up to the caller to clone them if needed.

        .. note:: The full parameters can be modified, but only the portion
            corresponding to the local param shard will persist after the
            context manager exits (unless ``volatile=True``, in which case there
            are no guarantees about persistence).

        Args:
            recurse (bool, Optional): recursively summon all params for nested
                FSDP instances (default: True)
            volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
                enabling this can be slightly more efficient (default: False)
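
        Example (illustrative sketch; counts the parameters of the wrapped
        model after training, assuming ``sharded_module`` is the FSDP-wrapped
        module)::

            with sharded_module.summon_full_params():
                # p.data temporarily points at the full (unsharded) tensors
                num_params = sum(p.numel() for p in sharded_module.parameters())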
        """
        if recurse:
            with contextlib.ExitStack() as stack:
                # Summon all params for any nested FSDP instances.
                for module in self.modules():
                    if isinstance(module, FullyShardedDataParallel):
                        stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
                # Yield to the caller, with full params in all nested instances.
                yield
            # Exiting from the ExitStack will re-shard params.
            return
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            self.assert_state(TrainingState.IDLE)
            # Set the state so that we assert when trying to go into
            # forward/backward.
            self.training_state = TrainingState.SUMMON_FULL_PARAMS
            full_tensors = self._rebuild_full_params(full_precision=True)
            assert full_tensors is not None
            with contextlib.ExitStack() as stack:
                if self.flatten_parameters and self.module.is_flattened:
                    # Update flattened views to point to fully-sized tensors. We
                    # use self.params[0] instead of full_tensors since the
                    # latter may contain padding.
                    assert len(self.params) == 1
                    assert isinstance(self.module, FlattenParamsWrapper)
                    stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
                try:
                    yield
                finally:
                    stack.close()
                    assert len(full_tensors) == len(self.params)
                    for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
                        if not volatile:
                            # Copy any changes made to the full params back into
                            # the corresponding local shards.
                            local_shard = self._get_shard(full_tensor)
                            p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                        if safe_to_free:
                            free_storage_(full_tensor)
                    self.has_full_params = False
                    self._use_fp32_param_shard()
                    self.training_state = TrainingState.IDLE

    def _reset_lazy_init(self) -> None:
        """Reset instance so :func:`_lazy_init` will run on the next forward."""
        self._is_root: Optional[bool] = None
        self._queue_wait_for_post_backward_closure: Optional[Callable] = None
        self._streams: Dict[str, torch.cuda.Stream] = {}
        self._reducer: Optional[ReduceScatterBucketer] = None

    def _lazy_init(self) -> None:
        """Initialization steps that should happen lazily, typically right
           before the first forward pass.
        """
        # Initialize param attributes lazily, in case the param's dtype or
        # device changes after __init__.
        for p in self.params:
            self._init_param_attributes(p)

        # Initialize _is_root and setup streams. These steps would ideally
        # happen in __init__, but _is_root can only be determined after the
        # entire model hierarchy is setup, thus we run it lazily.
        if self._is_root is None:
            self._set_is_root()
            self._setup_streams()

        if self.cpu_offload:  # Buffers stay on GPU, and don't get sharded
            self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
        else:
            self._all_buffers_to(dtype=self.buffer_dtype)

        if self._is_root:
            # Don't free the full params for the outer-most (root) instance,
            # since those params will be needed immediately after for the
            # backward pass.
            self.reshard_after_forward = False

            # Due to the use of streams, we need to make sure the previous
            # ``optim.step()`` is done before we all-gather parameters.
            self._wait_for_previous_optim_step()

    @torch.no_grad()
    def _init_param_attributes(self, p: Parameter) -> None:
        """
        We manage several attributes on each Parameter instance. The first two
        are set by :func:`_shard_parameters_`:

            ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param).
            ``_orig_size``: the size of the original Parameter (before sharding)

        The remaining attributes are set here:
            ``_fp32_shard``: a single shard of the parameters in full precision
                (typically FP32, but this is dependent on the dtype of the model
                as it's passed in by the user). This can be on CPU or GPU
                depending on the value of *``cpu_offload``*.
            ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
                a single shard of the parameters in FP16, used for all-gather.
            ``_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. This will be resized in place and
                only materialized (via all-gather) as needed.
        """
        assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
        if hasattr(p, "_fp32_shard"):
            return

        # Compute device defaults to CUDA when *cpu_offload* is enabled, or the
        # param's current device otherwise (could be CPU).
        compute_device = torch.device("cuda") if self.cpu_offload else p.device

        # A single shard of the parameters in full precision.
        p._fp32_shard = p.data

        if self.mixed_precision:
            assert p._fp32_shard.dtype == torch.float32

            if self.cpu_offload:
                assert p._fp32_shard.device == torch.device("cpu")
                # If we plan to keep the FP32 parameters on CPU, then pinning
                # memory allows us to later use non-blocking transfers when moving
                # the FP32 param shard to compute_device.
                p._fp32_shard = p._fp32_shard.pin_memory()
                p.data = p._fp32_shard

            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed.
            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
            free_storage_(p._fp16_shard)
        else:
            p._fp16_shard = None  # use _fp32_shard

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p._is_sharded:
            p._full_param_padded = torch.zeros(
                p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
            )
            free_storage_(p._full_param_padded)

        if self.move_grads_to_cpu:
            # We can optionally move the grad shard to CPU during the backward
            # pass. In this case, it's important to pre-allocate the CPU grad
            # shard in pinned memory so that we can do a non-blocking transfer.
            p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

    def _set_is_root(self) -> None:
        """If ``True``, implies that no other :class:`FullyShardedDataParallel`
        instance wraps this one. Called once by :func:`_lazy_init`.
        Also sets self.children_share_process_group = True if all child
        instances share the same process group. If some child instances use a
        different process group, self.clip_grad_norm_ will raise an error.
        """
        if self._is_root is not None:
            return
        # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
        self._is_root = True
        assert self._queue_wait_for_post_backward_closure is None
        self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
        # As the root, we now set all children instances to False and
        # give them a closure to try to queue a wait_for_post_backward.
        self.children_share_process_group = True
        for n, m in self.named_modules():
            # `n != ""` excludes self.
            if n != "" and isinstance(m, FullyShardedDataParallel):
                assert m._is_root is None
                m._is_root = False
                # When the root instance doesn't have params, allow children instances
                # to queue the wait_for_post_backward callback.
                #
                # TODO (Min): we should think about whether we can have an empty param at
                #             the root so that the root always has a callback on the
                #             backward graph.
                if not self._has_params:
                    assert m._queue_wait_for_post_backward_closure is None
                    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                if m.process_group != self.process_group:
                    self.children_share_process_group = False

    def _setup_streams(self) -> None:
        """Create streams to overlap data transfer and computation."""
        if len(self._streams) > 0 or not self._is_root:
            return
        # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
        self._streams["fp32_to_fp16"] = torch.cuda.Stream()
        # Stream for all-gathering parameters.
        self._streams["all_gather"] = torch.cuda.Stream()
        # Stream for overlapping grad reduction with the backward pass.
        self._streams["post_backward"] = torch.cuda.Stream()
        # Helper for bucketing reduce-scatter ops. This is also shared with
        # children instances to improve bucket utilization.
        self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
        # We share streams with all children instances, which allows them to
        # overlap transfers across the forward pass without synchronizing with
        # the default stream.
        for n, m in self.named_modules():
            if n != "" and isinstance(m, FullyShardedDataParallel):
                m._streams = self._streams
                m._reducer = self._reducer

    def _wait_for_previous_optim_step(self) -> None:
        """
        The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
        instance) needs to synchronize with the default stream to ensure the
        previous optimizer step is done.
        """
        if self.mixed_precision:
            self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
        else:
            self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        self._lazy_init()

        # Start of a forward pass.
        self.training_state = TrainingState.FORWARD

        if self._is_root and self.mixed_precision:
            args, kwargs = cast_inputs_to_fp16(*args, **kwargs)

        # All-gather full parameters. This will also transfer FP32 parameters to
        # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
        self._rebuild_full_params()

        # Register backward hooks to reshard params and reduce-scatter grads.
        # These need to be re-registered every forward pass.
        self._register_post_backward_hooks()

        outputs = self.module(*args, **kwargs)

        if self.reshard_after_forward:
            self._free_full_params()

        # Switch to main FP32 param shard. We maintain this invariant throughout
        # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
        # also ensures that after the first forward, the optimizer state will be
        # initialized with the correct dtype and (sharded) size, since optimizer
        # state is typically initialized lazily in ``optim.step()``.
        self._use_fp32_param_shard()

        # Register pre-backward hooks to all-gather the params for the backward
        # pass (if needed).
        outputs = self._register_pre_backward_hooks(outputs)

        # Done with a forward pass.
        self.training_state = TrainingState.IDLE

        return outputs

    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
        """Register pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward."""
        if not torch.is_grad_enabled():
            return outputs  # don't register hooks if grad isn't enabled

        pre_backward_hook_has_run = [False]

        def _pre_backward_hook(*unused: Any) -> None:
            if pre_backward_hook_has_run[0]:
                return  # only run once
            pre_backward_hook_has_run[0] = True

            # Start of a backward pass.
            self.training_state = TrainingState.BACKWARD

            # All-gather full parameters.
            if self.reshard_after_forward:
                self._rebuild_full_params()
            else:
                self._use_full_params()

            # Make sure p.grad has the correct size/device (or set it to None).
            self._prep_grads_for_backward()

        def _register_hook(t: torch.Tensor) -> torch.Tensor:
            if t.requires_grad:
                t.register_hook(_pre_backward_hook)
            return t

        # Attach hooks to Tensor outputs.
        outputs = apply_to_tensors(_register_hook, outputs)

        return outputs

    def _register_post_backward_hooks(self) -> None:
        """Register backward hooks to reshard params and reduce-scatter grads."""
        if not torch.is_grad_enabled():
            return  # don't register grad hooks if grad isn't enabled
        if self._is_root:
            # This actually means that only root instance has this field
            # defined. Accidentally accessing this field will assert on all
            # other instances, giving us a nice bug checker.
            self._post_backward_callback_queued = False
        for p in self.params:
            if p.requires_grad:
                if hasattr(p, "_shard_bwd_hook"):
                    p._shard_bwd_hook[1].remove()  # remove existing handle
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
                p._shard_bwd_hook = (grad_acc, handle)

    @torch.no_grad()
    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        """
        At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will replace
        ``param.grad`` with a single shard of the summed gradient across all
        GPUs. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by :func:`_shard_parameters_`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        self.assert_state(TrainingState.BACKWARD)
        if param.grad is None:
            return
        if param.grad.requires_grad:
            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")

        if not self._is_root or self._require_backward_grad_sync:
            # Free full params. As a special case, we don't free the full params
            # on the root instance when in a ``no_sync`` context (as indicated
            # by ``self._require_backward_grad_sync``), since we will need the
            # params again immediately for the next forward.
            self._free_full_params([param])

        if self.mixed_precision:
            # This is a no-op if reshard_after_forward is True, since we already
            # free the param shard when rebuilding the full params in the
            # pre_backward_hook.
            self._free_fp16_param_shard([param])

        # Switch to FP32 shard after backward.
        self._use_fp32_param_shard([param])

        # (Try to) enqueue a callback at the end of the backward pass to ensure that
        # all post-backward work has finished. We only need one callback, but every
        # FSDP instance (root and children) attempts to queue it here, so that it is
        # queued no matter which instance(s) actually own params.
        assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
        if self._queue_wait_for_post_backward_closure is not None:
            self._queue_wait_for_post_backward_closure()

        if not self._require_backward_grad_sync:
            return

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data

            if self.mixed_precision and self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.to(param.dtype)

            if self.world_size > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
                param.grad.data.div_(self.world_size)

            callback_fn = functools.partial(self._post_reduction_hook, param)
            if param._is_sharded:
                assert param._is_sharded
                assert self._reducer is not None
                grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
                self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                callback_fn(param.grad.data)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])

    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        assert param.grad is not None
        self.assert_state(TrainingState.BACKWARD)
        param.grad.data = reduced_grad
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
        # Optionally move gradients to CPU, typically used if one is running
        # the optimizer on the CPU.
        if self.move_grads_to_cpu:
            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
            param.grad.data = param._cpu_grad
        # Don't let this memory get reused until after the transfers.
        reduced_grad.record_stream(torch.cuda.current_stream())

    def _queue_wait_for_post_backward(self) -> None:
        """Try to queue a `wait_for_post_backward` callback.

        Only called on the root instance, and only one callback is ever queued.
        Children FSDP instances can also invoke this via a closure, in case the
        root instance doesn't own any params.
        """
        assert self._is_root
        self.assert_state(TrainingState.BACKWARD)
        if not self._post_backward_callback_queued:
            self._post_backward_callback_queued = True
            Variable._execution_engine.queue_callback(self._wait_for_post_backward)

    @torch.no_grad()
    def _wait_for_post_backward(self) -> None:
        """Wait for post-backward to finish. Only called on root instance."""
        assert self._is_root
        self.assert_state(TrainingState.BACKWARD)
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self._streams["post_backward"]):
                assert self._reducer is not None
                self._reducer.flush()
            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
            if self.move_grads_to_cpu:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        # A backward pass is done, update root and nested FSDP's flags.
        for m in self.modules():  # includes self
            if isinstance(m, FullyShardedDataParallel):
                m.assert_state(TrainingState.BACKWARD)
                m.training_state = TrainingState.IDLE

    @torch.no_grad()
    def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Args:
            full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``full_precision=False`` and the full params are already gathered.
        """
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            if custom_output_tensor is not None:
                assert p._is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p._is_sharded:
                if self.mixed_precision and not full_precision:
                    p.data = p._fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p._full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if self.mixed_precision and not full_precision:
                self._cast_fp32_param_shards_to_fp16()

            for p in self.params:
                if not p._is_sharded:  # e.g., when world_size == 1
                    update_p_data()
                else:
                    # If self.cpu_offload and full_precision, we need to cast
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p._full_param_padded.device)

                    p_size = p._full_param_padded.size()
                    assert p_size.numel() % self.world_size == 0
                    if not self.mixed_precision or not full_precision:
                        if p._full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage_(p._full_param_padded, size=p_size)
                        output_tensor = p._full_param_padded
                    else:
                        # Allocate fresh tensor in full precision.
                        output_tensor = p_data.new_zeros(p_size)

                    # Fill output_tensor with p.data from each rank (one chunk per rank).
                    chunks = list(output_tensor.chunk(self.world_size))
                    dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if self.mixed_precision and not full_precision:
                        self._free_fp16_param_shard([p])
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def _use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.
        """
        assert self.has_full_params
        for p in self.params:
            if not p._is_sharded:
                if self.mixed_precision:
                    assert p._fp16_shard.storage().size() != 0
                    p.data = p._fp16_shard
            else:
                assert p._full_param_padded.storage().size() != 0
                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)

    @torch.no_grad()
    def _prep_grads_for_backward(self) -> None:
        """Make sure p.grad has the correct size/device, otherwise set it to None."""
        for p in self.params:
            if p.grad is not None and (p.grad.size() != p._orig_size or p.grad.device != p.data.device):
                p.grad = None

    @torch.no_grad()
    def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        current_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self._streams["all_gather"]):
            for p in params:
                if not p._is_sharded:  # e.g., world_size == 1
                    if self.mixed_precision:
                        self._free_fp16_param_shard([p])
                    continue
                # There may be external references to the Tensor Storage that we
                # can't modify, such as references that are created by
                # ctx.save_for_backward in the forward pass. Thus when we
                # unshard parameters, we should reuse the original Tensor
                # Storage object and unshard it in-place. For now, just resize
                # the Storage to 0 to save memory.
                p._full_param_padded.record_stream(current_stream)
                free_storage_(p._full_param_padded)

    @torch.no_grad()
    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p._fp32_shard

    @torch.no_grad()
    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p._fp16_shard is not None
                alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
                p._fp16_shard.copy_(
                    # If cpu_offload is True, this will be non-blocking because
                    # _fp32_shard is pinned, otherwise it's a no-op.
                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
                )
                p.data = p._fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p._fp16_shard is not None:
                # _fp16_shard is allocated in _fp32_to_fp16_stream, so we can't
                # free it until the work in the current stream completes.
                p._fp16_shard.record_stream(current_stream)
                free_storage_(p._fp16_shard)

    def assert_state(self, state: TrainingState) -> None:
        """Assert we are in the given state."""
        assert (
            self.training_state == state
        ), f"expected to be in state {state} but current state is {self.training_state}"


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.
    """

    def fn(x: torch.Tensor) -> torch.Tensor:
        if x.dtype is torch.float32:
            return x.half()
        return x

    return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def cast_buffers_(
    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
    # if buffers are already on the right device and/or dtype this is just python loop cost
    assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
    for key, buf in module.named_buffers(recurse=False):
        if buf is not None:
            buf = buf.to(device=device)
            if torch.is_floating_point(buf):
                buf = buf.to(dtype=dtype)
            setattr(module, key, buf)


def free_storage_(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


@torch.no_grad()
def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
    """Allocate storage for a tensor."""
    if data.storage().size() == size.numel():  # no need to reallocate
        return
    assert data.storage().size() == 0
    data.storage().resize_(size.numel())


def _post_state_dict_hook(
    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
        # We copy the state_dict since full param will be freed after
        # we exit the summon_full_params() context.
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].clone()

    # Remove "_fsdp_wrapped_module." prefix
    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
    return state_dict


def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")