Unverified Commit 6b2897ca authored by Myle Ott, committed by GitHub

[cleanup] FSDP docstrings (#428)

parent 2478a9ad
@@ -37,7 +37,11 @@ release = "0.3.0"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosectionlabel"]
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosectionlabel",
"sphinx.ext.napoleon", # support NumPy and Google style docstrings
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
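For reference, ``sphinx.ext.napoleon`` is what lets autodoc render Google- and NumPy-style docstrings such as the FSDP ones below. A minimal sketch of the Google style it parses (the function itself is hypothetical)::

    def scale(x: float, factor: float = 2.0) -> float:
        """Multiply ``x`` by ``factor``.

        Args:
            x (float): value to scale.
            factor (float): multiplier. Default: 2.0.

        Returns:
            The scaled value.
        """
        return x * factor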
@@ -63,6 +63,7 @@ class FullyShardedDataParallel(nn.Module):
Usage::
    torch.cuda.set_device(device_id)
    sharded_module = FullyShardedDataParallel(my_module)
+   optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
    x = sharded_module(x, y=3, z=torch.Tensor([1]))
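Continuing the usage above, a full step would typically finish with a backward pass and an optimizer update; the loss shown here is an illustrative sketch, not part of the diff::

    loss = x.sum()    # any scalar loss; .sum() is only for illustration
    loss.backward()   # gradients are reduce-scattered and left sharded across ranks
    optim.step()      # each rank updates only the shard it owns
    optim.zero_grad()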
@@ -72,11 +73,12 @@ class FullyShardedDataParallel(nn.Module):
It is also possible to shard individual layers separately and have an outer
wrapper handle any leftover parameters. This can be helpful to further
- reduce memory usage and to improve training speed by distributing the
- unsharding (all-gather) across the forward pass. For example::
+ reduce GPU memory usage, reduce system memory usage when initializing large
+ models, and improve training speed by overlapping the all-gather step
+ across the forward pass. For example::
    sharded_model = FullyShardedDataParallel(
-       nn.Sequential(
+       nn.Sequential(  # doesn't have to be nn.Sequential
            nn.Linear(5, 100),
            FullyShardedDataParallel(nn.Linear(100, 100)),
            FullyShardedDataParallel(nn.Linear(100, 100)),
@@ -84,46 +86,49 @@ class FullyShardedDataParallel(nn.Module):
        )
    )
+ .. warning::
+     The optimizer must be initialized *after* the module has been wrapped,
+     since FSDP will shard parameters in-place and this will break any
+     previously initialized optimizers.
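A minimal sketch of the ordering this warning requires, assuming ``my_module`` is an ordinary ``nn.Module``::

    # Too early: this optimizer would hold references to the original,
    # unsharded parameters, which wrapping then shards in-place.
    # optim = torch.optim.Adam(my_module.parameters(), lr=1e-4)

    sharded_module = FullyShardedDataParallel(my_module)             # wrap first
    optim = torch.optim.Adam(sharded_module.parameters(), lr=1e-4)   # then build the optimizer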
Args:
module (nn.Module):
module to checkpoint
process_group (Optional):
process group for sharding
reshard_after_forward (bool, Optional):
- if ``True``, reshard parameters
- after the forward pass. This saves memory but slows training. This
- is only relevant when resharding individual layers.
+ if ``True``, reshard parameters after the forward pass. This saves
+ memory but slows training. This is only relevant when resharding
+ individual layers.
mixed_precision (bool, Optional):
- if ``True``, inputs, activations and
- gradients will be kept in FP16; computation and communication will
- occur in FP16; and a (sharded) master copy of the model weights will
- be maintained in FP32.
+ if ``True``, inputs, activations and gradients will be kept in FP16;
+ computation and communication will occur in FP16; and a (sharded)
+ master copy of the model weights will be maintained in FP32.
fp32_reduce_scatter (bool, Optional):
- if ``True``, then reduce-scatter
- gradients in FP32. This is only relevant when *``mixed_precision``*
- is ``True``.
+ if ``True``, then reduce-scatter gradients in FP32. This is only
+ relevant when *``mixed_precision``* is ``True``.
flatten_parameters (bool, Optional):
- if ``True``, flatten parameters
- into a single contiguous tensor, which improves training speed.
+ if ``True``, flatten parameters into a single contiguous tensor,
+ which improves training speed.
cpu_offload (bool, Optional):
- if ``True``, offload FP32 params to CPU.
- This is only relevant when *``mixed_precision``* is ``True``.
+ if ``True``, offload FP32 params to CPU. This is only relevant when
+ *``mixed_precision``* is ``True``.
compute_dtype (torch.dtype, Optional):
- dtype for full parameters for
- computation. This defaults to ``torch.float32`` unless
- *``mixed_precision``* is set, in which case it defaults to
- ``torch.float16``.
+ dtype for full parameters for computation. This defaults to
+ ``torch.float32`` unless *``mixed_precision``* is set, in which case
+ it defaults to ``torch.float16``.
move_grads_to_cpu (bool, Optional):
- move gradient shard to CPU after
- reduction. This is useful when combined with CPU-based optimizers.
- It defaults to the value of *``cpu_offload``*.
+ move gradient shard to CPU after reduction. This is useful when
+ combined with CPU-based optimizers. It defaults to the value of
+ *``cpu_offload``*.
bucket_cap_mb (int, Optional):
- FSDP will bucket parameters so that
- gradient reduction can potentially overlap with backward
- computation. bucket_cap_mb controls the bucket size in MegaBytes
- (MB). Buckets are sub-divided based on world_size, so the max shard
- size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable
- bucketing. Default: 25.
+ FSDP will bucket parameters so that gradient reduction can
+ potentially overlap with backward computation. bucket_cap_mb
+ controls the bucket size in MegaBytes (MB). Buckets are sub-divided
+ based on world_size, so the max shard size is roughly
+ ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
+ Default: 25.
"""
def __init__(
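To illustrate how the arguments documented above combine, a hedged construction sketch follows; the flag values are arbitrary examples rather than recommendations. With the default ``bucket_cap_mb=25`` and a world size of 8, the maximum shard size per bucket works out to roughly 25 / 8 ≈ 3 MB::

    sharded_module = FullyShardedDataParallel(
        my_module,                   # any nn.Module; placeholder here
        reshard_after_forward=True,  # re-shard after forward: saves memory, slows training
        mixed_precision=True,        # FP16 inputs/activations/grads, FP32 sharded master weights
        flatten_parameters=True,     # flatten params into one contiguous tensor for speed
        bucket_cap_mb=25,            # bucket size (MB) for gradient reduction; <= 0 disables it
    )
    optim = torch.optim.Adam(sharded_module.parameters(), lr=1e-4)   # after wrapping, as above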
@@ -207,22 +212,27 @@ class FullyShardedDataParallel(nn.Module):
# filter_params_fn: Callable[[Any], Any] = None,
) -> torch.Tensor:
"""
- Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
- concatenated into a single vector. Gradients are modified in-place.
+ Clip all gradients at this point in time. The norm is computed over all
+ gradients together, as if they were concatenated into a single vector.
+ Gradients are modified in-place.
- Arguments:
+ Args:
max_norm (float or int): max norm of the gradients
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+     for infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
- .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
-     under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
-     in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
- .. warning: This needs to be called on all ranks, since synchronization primitives will be used
+ .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
+     handles the partitioning and multiple devices per rank under the
+     hood. The default torch util is not applicable here, because each
+     rank only has a partial view of all the grads in the model, so
+     calling it in the OSS context would lead to different scaling being
+     applied per subset of model parameters.
+ .. warning:: This needs to be called on all ranks, since synchronization
+     primitives will be used.
"""
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
assert self.training_state == TrainingState.IDLE
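A sketch of where gradient clipping sits in a training step, assuming the method is exposed as ``clip_grad_norm_`` on the wrapped module (the assert above refers to it) and ``batch`` is a hypothetical input; every rank must make the call::

    loss = sharded_module(batch).sum()   # illustrative forward pass and loss
    loss.backward()

    # Collective call: every rank must reach this line.
    total_norm = sharded_module.clip_grad_norm_(max_norm=1.0, norm_type=2.0)

    optim.step()
    optim.zero_grad()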
@@ -367,7 +377,11 @@ class FullyShardedDataParallel(nn.Module):
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
- wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
+ wrapped Module without any sharding-specific logic. Returned tensors
+ will always be typed float32.
+
+ .. warning:: This needs to be called on all ranks, since synchronization
+     primitives will be used.
"""
torch.cuda.synchronize()
self._lazy_init()
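A sketch of consolidated checkpointing under these semantics; every rank calls ``state_dict()``, while the rank-0 check and file path are purely illustrative::

    full_state = sharded_module.state_dict()   # runs on all ranks; tensors are unsharded FP32

    if torch.distributed.get_rank() == 0:
        torch.save(full_state, "checkpoint.pt")   # hypothetical path; only one rank writes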
@@ -399,7 +413,12 @@ class FullyShardedDataParallel(nn.Module):
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""Load a whole (unsharded) state_dict."""
"""
Load a whole (unsharded) state_dict.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
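The matching load, again invoked on every rank; the checkpoint path is illustrative::

    full_state = torch.load("checkpoint.pt", map_location="cpu")   # hypothetical path
    sharded_module.load_state_dict(full_state, strict=True)        # runs on all ranks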
@@ -550,9 +569,10 @@ class FullyShardedDataParallel(nn.Module):
def _set_is_root(self) -> None:
"""If ``True``, implies that no other :class:`FullyShardedDataParallel`
instance wraps this one. Called once by :func:`_lazy_init`.
- Also sets self.children_share_process_group = True if all child instances share the same process group.
- If some child instances use a different process group, self.clip_grad_norm_ will raise an error.
- """
+ Also sets self.children_share_process_group = True if all child
+ instances share the same process group. If some child instances use a
+ different process group, self.clip_grad_norm_ will raise an error.
+ """
if self._is_root is not None:
return
# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False
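For illustration, a nesting in which no instance is given an explicit process group; assuming both then fall back to the same default group, the root would see all children sharing one group and ``clip_grad_norm_`` remains usable. Sizes are arbitrary::

    inner = FullyShardedDataParallel(nn.Linear(100, 100))   # no process_group passed
    root = FullyShardedDataParallel(nn.Sequential(nn.Linear(5, 100), inner))
    # Assumption: both instances use the same (default) process group, so the
    # root detects that its children share it and clip_grad_norm_ stays supported.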