Unverified Commit 6b2897ca authored by Myle Ott, committed by GitHub

[cleanup] FSDP docstrings (#428)

parent 2478a9ad
@@ -37,7 +37,11 @@ release = "0.3.0"
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosectionlabel"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosectionlabel",
+    "sphinx.ext.napoleon",  # support NumPy and Google style docstrings
+]

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ["_templates"]
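For context, ``sphinx.ext.napoleon`` is what lets autodoc render Google- and NumPy-style sections such as the ``Args:`` blocks used in the FSDP docstrings below. A minimal sketch of a Google-style docstring that napoleon can parse (the function and its names are invented purely for illustration)::

    def scale(x: float, factor: float = 2.0) -> float:
        """Scale a value by a constant factor.

        Args:
            x (float): value to scale.
            factor (float, Optional): multiplier. Default: 2.0.

        Returns:
            The scaled value.
        """
        return x * factor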
@@ -63,6 +63,7 @@ class FullyShardedDataParallel(nn.Module):

     Usage::

+        torch.cuda.set_device(device_id)
         sharded_module = FullyShardedDataParallel(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
         x = sharded_module(x, y=3, z=torch.Tensor([1]))
@@ -72,11 +73,12 @@ class FullyShardedDataParallel(nn.Module):

     It is also possible to shard individual layers separately and have an outer
     wrapper handle any leftover parameters. This can be helpful to further
-    reduce memory usage and to improve training speed by distributing the
-    unsharding (all-gather) across the forward pass. For example::
+    reduce GPU memory usage, reduce system memory usage when initializing large
+    models and to improve training speed by overlapping the all-gather step
+    across the forward pass. For example::

         sharded_model = FullyShardedDataParallel(
-            nn.Sequential(
+            nn.Sequential(  # doesn't have to be nn.Sequential
                 nn.Linear(5, 100),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
@@ -84,46 +86,49 @@ class FullyShardedDataParallel(nn.Module):
             )
         )

+    .. warning::
+
+        The optimizer must be initialized *after* the module has been wrapped,
+        since FSDP will shard parameters in-place and this will break any
+        previously initialized optimizers.
+
     Args:
         module (nn.Module):
             module to checkpoint
         process_group (Optional):
             process group for sharding
         reshard_after_forward (bool, Optional):
-            if ``True``, reshard parameters
-            after the forward pass. This saves memory but slows training. This
-            is only relevant when resharding individual layers.
+            if ``True``, reshard parameters after the forward pass. This saves
+            memory but slows training. This is only relevant when resharding
+            individual layers.
         mixed_precision (bool, Optional):
-            if ``True``, inputs, activations and
-            gradients will be kept in FP16; computation and communication will
-            occur in FP16; and a (sharded) master copy of the model weights will
-            be maintained in FP32.
+            if ``True``, inputs, activations and gradients will be kept in FP16;
+            computation and communication will occur in FP16; and a (sharded)
+            master copy of the model weights will be maintained in FP32.
         fp32_reduce_scatter (bool, Optional):
-            if ``True``, then reduce-scatter
-            gradients in FP32. This is only relevant when *``mixed_precision``*
-            is ``True``.
+            if ``True``, then reduce-scatter gradients in FP32. This is only
+            relevant when *``mixed_precision``* is ``True``.
         flatten_parameters (bool, Optional):
-            if ``True``, flatten parameters
-            into a single contiguous tensor, which improves training speed.
+            if ``True``, flatten parameters into a single contiguous tensor,
+            which improves training speed.
         cpu_offload (bool, Optional):
-            if ``True``, offload FP32 params to CPU.
-            This is only relevant when *``mixed_precision``* is ``True``.
+            if ``True``, offload FP32 params to CPU. This is only relevant when
+            *``mixed_precision``* is ``True``.
         compute_dtype (torch.dtype, Optional):
-            dtype for full parameters for
-            computation. This defaults to ``torch.float32`` unless
-            *``mixed_precision``* is set, in which case it defaults to
-            ``torch.float16``.
+            dtype for full parameters for computation. This defaults to
+            ``torch.float32`` unless *``mixed_precision``* is set, in which case
+            it defaults to ``torch.float16``.
         move_grads_to_cpu (bool, Optional):
-            move gradient shard to CPU after
-            reduction. This is useful when combined with CPU-based optimizers.
-            It defaults to the value of *``cpu_offload``*.
+            move gradient shard to CPU after reduction. This is useful when
+            combined with CPU-based optimizers. It defaults to the value of
+            *``cpu_offload``*.
         bucket_cap_mb (int, Optional):
-            FSDP will bucket parameters so that
-            gradient reduction can potentially overlap with backward
-            computation. bucket_cap_mb controls the bucket size in MegaBytes
-            (MB). Buckets are sub-divided based on world_size, so the max shard
-            size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable
-            bucketing. Default: 25.
+            FSDP will bucket parameters so that gradient reduction can
+            potentially overlap with backward computation. bucket_cap_mb
+            controls the bucket size in MegaBytes (MB). Buckets are sub-divided
+            based on world_size, so the max shard size is roughly
+            ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
+            Default: 25.
     """

     def __init__(
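A minimal sketch of the wrap-then-optimize order that the new warning calls for, assuming FSDP is importable from ``fairscale.nn``; ``MyModel`` and ``device_id`` are placeholders, and the flag values are only illustrative::

    import torch
    from fairscale.nn import FullyShardedDataParallel as FSDP

    torch.cuda.set_device(device_id)  # placeholder: this rank's GPU index

    # Wrap first: FSDP shards parameters in-place, so an optimizer built on
    # the unwrapped module would hold references to the old, full parameters.
    model = FSDP(MyModel().cuda(), mixed_precision=True, reshard_after_forward=True)

    # Only now build the optimizer, over the (sharded) FSDP parameters.
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)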
@@ -207,22 +212,27 @@ class FullyShardedDataParallel(nn.Module):
         # filter_params_fn: Callable[[Any], Any] = None,
     ) -> torch.Tensor:
         """
-        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
-        concatenated into a single vector. Gradients are modified in-place.
+        Clip all gradients at this point in time. The norm is computed over all
+        gradients together, as if they were concatenated into a single vector.
+        Gradients are modified in-place.

-        Arguments:
+        Args:
             max_norm (float or int): max norm of the gradients
-            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+                for infinity norm.

         Returns:
             Total norm of the parameters (viewed as a single vector).

-        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
-        under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
-        in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
+        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
+            handles the partitioning and multiple devices per rank under the
+            hood. The default torch util is not applicable here, because each
+            rank only has a partial view of all the grads in the model, so
+            calling it in the OSS context would lead to different scaling being
+            applied per subset of model parameters.

-        .. warning: This needs to be called on all ranks, since synchronization primitives will be used
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
         assert self.training_state == TrainingState.IDLE
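A rough sketch of where this collective clip fits in a training step, continuing the placeholder ``model`` and ``optim`` from the sketch above and assuming every rank runs the same code::

    loss = model(batch).sum()  # ``batch`` is a placeholder input
    loss.backward()
    # Must run on every rank, and only on the root FSDP instance: it uses
    # collectives to combine the per-shard gradient norms before clipping.
    total_norm = model.clip_grad_norm_(max_norm=1.0)
    optim.step()
    optim.zero_grad()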
@@ -367,7 +377,11 @@ class FullyShardedDataParallel(nn.Module):
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
-        wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
+        wrapped Module without any sharding-specific logic. Returned tensors
+        will always be typed float32.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         torch.cuda.synchronize()
         self._lazy_init()
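Per the warning above, ``state_dict()`` has to run on every rank even if only one rank writes the checkpoint. A hedged sketch, reusing the placeholder ``model`` and an illustrative file name::

    import torch
    import torch.distributed as dist

    full_state = model.state_dict()  # collective: all ranks participate
    if dist.get_rank() == 0:
        torch.save(full_state, "checkpoint.pt")  # path is illustrative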
@@ -399,7 +413,12 @@ class FullyShardedDataParallel(nn.Module):
     def load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
-        """Load a whole (unsharded) state_dict."""
+        """
+        Load a whole (unsharded) state_dict.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
+        """
         torch.cuda.synchronize()
         self._lazy_init()
         self._rebuild_full_params()
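Loading mirrors saving: every rank reads the full (unsharded) weights and calls ``load_state_dict()``. A sketch with the same illustrative file name::

    full_state = torch.load("checkpoint.pt", map_location="cpu")
    model.load_state_dict(full_state)  # collective: all ranks participate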
@@ -550,9 +569,10 @@ class FullyShardedDataParallel(nn.Module):
     def _set_is_root(self) -> None:
         """If ``True``, implies that no other :class:`FullyShardedDataParallel`
         instance wraps this one. Called once by :func:`_lazy_init`.
-        Also sets self.children_share_process_group = True if all child instances share the same process group.
-        If some child instances use a different process group, self.clip_grad_norm_ will raise an error.
+        Also sets self.children_share_process_group = True if all child
+        instances share the same process group. If some child instances use a
+        different process group, self.clip_grad_norm_ will raise an error.
         """
         if self._is_root is not None:
             return
         # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False