Commit 6b2897ca (unverified) in fairscale
[cleanup] FSDP docstrings (#428)

Authored Feb 25, 2021 by Myle Ott; committed by GitHub on Feb 25, 2021
Parent commit: 2478a9ad
Changes: 2 files, 69 additions and 45 deletions (+69, -45)
docs/source/conf.py (+5, -1)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+64, -44)
docs/source/conf.py (view file @ 6b2897ca)
@@ -37,7 +37,11 @@ release = "0.3.0"
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosectionlabel"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosectionlabel",
+    "sphinx.ext.napoleon",  # support NumPy and Google style docstrings
+]
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ["_templates"]
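For context, ``sphinx.ext.napoleon`` teaches autodoc to parse Google- and NumPy-style docstrings such as the FSDP ones cleaned up below. A minimal sketch of the Google style it recognizes (the function itself is purely illustrative):

    def scale(x: float, factor: float = 2.0) -> float:
        """Multiply ``x`` by ``factor``.

        Args:
            x (float): value to scale.
            factor (float, Optional): scaling factor. Default: 2.0.

        Returns:
            The scaled value.
        """
        return x * factor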
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (view file @ 6b2897ca)
@@ -63,6 +63,7 @@ class FullyShardedDataParallel(nn.Module):
     Usage::
 
         torch.cuda.set_device(device_id)
         sharded_module = FullyShardedDataParallel(my_module)
+        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
         x = sharded_module(x, y=3, z=torch.Tensor([1]))
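To see why the ``optim = ...`` line was added to the usage snippet, here is a hedged end-to-end sketch; the process-group setup, the ``nn.Linear`` module, the input shape and the loss are illustrative assumptions, not part of the docstring:

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel

    def train_step(device_id: int) -> None:
        # Assumes torch.distributed is already initialized (one process per GPU).
        torch.cuda.set_device(device_id)
        my_module = nn.Linear(5, 3).cuda(device_id)  # illustrative module
        sharded_module = FullyShardedDataParallel(my_module)
        # Create the optimizer *after* wrapping: FSDP shards parameters in-place,
        # so an optimizer built earlier would point at stale tensors.
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        x = torch.randn(8, 5).cuda(device_id)  # illustrative input
        loss = sharded_module(x).sum()         # illustrative loss
        loss.backward()
        optim.step()
        optim.zero_grad()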
@@ -72,11 +73,12 @@ class FullyShardedDataParallel(nn.Module):
     It is also possible to shard individual layers separately and have an outer
     wrapper handle any leftover parameters. This can be helpful to further
-    reduce memory usage and to improve training speed by distributing the
-    unsharding (all-gather) across the forward pass. For example::
+    reduce GPU memory usage, reduce system memory usage when initializing large
+    models and to improve training speed by overlapping the all-gather step
+    across the forward pass. For example::
 
         sharded_model = FullyShardedDataParallel(
-            nn.Sequential(
+            nn.Sequential(  # doesn't have to be nn.Sequential
                 nn.Linear(5, 100),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
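A hedged sketch of the same idea with the inner wrapping written out; ``reshard_after_forward`` is the constructor argument documented in the next hunk, and the layer sizes are arbitrary:

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def build_sharded_model() -> FSDP:
        # Wrapping each large Linear separately means only one layer's full
        # parameters need to be all-gathered at a time during the forward pass;
        # the outer wrapper picks up whatever is left over (the two small Linears).
        inner = nn.Sequential(
            nn.Linear(5, 100),
            FSDP(nn.Linear(100, 100), reshard_after_forward=True),
            FSDP(nn.Linear(100, 100), reshard_after_forward=True),
            nn.Linear(100, 5),
        )
        return FSDP(inner)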
@@ -84,46 +86,49 @@ class FullyShardedDataParallel(nn.Module):
             )
         )
 
     .. warning::
 
         The optimizer must be initialized *after* the module has been wrapped,
         since FSDP will shard parameters in-place and this will break any
         previously initialized optimizers.
 
     Args:
         module (nn.Module):
             module to checkpoint
         process_group (Optional):
             process group for sharding
         reshard_after_forward (bool, Optional):
             if ``True``, reshard parameters after the forward pass. This saves
             memory but slows training. This is only relevant when resharding
             individual layers.
         mixed_precision (bool, Optional):
             if ``True``, inputs, activations and gradients will be kept in FP16;
             computation and communication will occur in FP16; and a (sharded)
             master copy of the model weights will be maintained in FP32.
         fp32_reduce_scatter (bool, Optional):
             if ``True``, then reduce-scatter gradients in FP32. This is only
             relevant when *``mixed_precision``* is ``True``.
         flatten_parameters (bool, Optional):
             if ``True``, flatten parameters into a single contiguous tensor,
             which improves training speed.
         cpu_offload (bool, Optional):
             if ``True``, offload FP32 params to CPU. This is only relevant when
             *``mixed_precision``* is ``True``.
         compute_dtype (torch.dtype, Optional):
             dtype for full parameters for computation. This defaults to
             ``torch.float32`` unless *``mixed_precision``* is set, in which case
             it defaults to ``torch.float16``.
         move_grads_to_cpu (bool, Optional):
             move gradient shard to CPU after reduction. This is useful when
             combined with CPU-based optimizers. It defaults to the value of
             *``cpu_offload``*.
         bucket_cap_mb (int, Optional):
             FSDP will bucket parameters so that gradient reduction can
             potentially overlap with backward computation. bucket_cap_mb
             controls the bucket size in MegaBytes (MB). Buckets are sub-divided
             based on world_size, so the max shard size is roughly
             ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
             Default: 25.
     """
 
     def __init__(
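Putting several of the documented arguments together, a hedged configuration sketch (the particular flag combination and the module are illustrative, not a recommendation from the docstring):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Assumes torch.distributed is already initialized.
    # Mixed-precision sharding with the FP32 master weights offloaded to CPU;
    # gradients are reduce-scattered in FP32, parameters are flattened into one
    # contiguous tensor, and reduction buckets stay at the default 25 MB.
    sharded = FSDP(
        nn.Linear(1024, 1024),
        mixed_precision=True,
        fp32_reduce_scatter=True,
        flatten_parameters=True,
        cpu_offload=True,  # move_grads_to_cpu then defaults to True as well
        bucket_cap_mb=25,
    )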
@@ -207,22 +212,27 @@ class FullyShardedDataParallel(nn.Module):
         # filter_params_fn: Callable[[Any], Any] = None,
     ) -> torch.Tensor:
         """
-        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
-        concatenated into a single vector. Gradients are modified in-place.
+        Clip all gradients at this point in time. The norm is computed over all
+        gradients together, as if they were concatenated into a single vector.
+        Gradients are modified in-place.
 
-        Arguments:
+        Args:
             max_norm (float or int): max norm of the gradients
-            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+                for infinity norm.
 
         Returns:
             Total norm of the parameters (viewed as a single vector).
 
-        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
-            under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
-            in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
-        .. warning: This needs to be called on all ranks, since synchronization primitives will be used
+        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
+            handles the partitioning and multiple devices per rank under the
+            hood. The default torch util is not applicable here, because each
+            rank only has a partial view of all the grads in the model, so
+            calling it in the OSS context would lead to different scaling being
+            applied per subset of model parameters.
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
         assert self.training_state == TrainingState.IDLE
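A hedged sketch of how the clipping call fits into a training loop; the ``clip_grad_norm_`` name matches the assertion message above, and the surrounding objects are assumed to come from a setup like the usage sketch earlier:

    import torch

    def clipped_step(sharded_module, optim, loss: torch.Tensor) -> torch.Tensor:
        # Every rank must make this call: the norm is computed collectively over
        # the sharded gradients, so skipping it on some ranks would deadlock.
        loss.backward()
        total_norm = sharded_module.clip_grad_norm_(max_norm=1.0)
        optim.step()
        optim.zero_grad()
        return total_norm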
@@ -367,7 +377,11 @@ class FullyShardedDataParallel(nn.Module):
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
-        wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
+        wrapped Module without any sharding-specific logic. Returned tensors
+        will always be typed float32.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         torch.cuda.synchronize()
         self._lazy_init()
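A hedged checkpoint-saving sketch built on this method; the rank check and file path are assumptions for illustration:

    import torch
    import torch.distributed as dist

    def save_checkpoint(sharded_module, path: str = "checkpoint.pt") -> None:
        # state_dict() must run on every rank (it gathers the full parameters),
        # but only rank 0 needs to write the consolidated FP32 result to disk.
        full_state = sharded_module.state_dict()
        if dist.get_rank() == 0:
            torch.save(full_state, path)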
@@ -399,7 +413,12 @@ class FullyShardedDataParallel(nn.Module):
     def load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
-        """Load a whole (unsharded) state_dict."""
+        """
+        Load a whole (unsharded) state_dict.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
+        """
         torch.cuda.synchronize()
         self._lazy_init()
         self._rebuild_full_params()
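And the matching load path, again a sketch under the same assumptions; per the warning, every rank participates and feeds the whole unsharded dict to the wrapper:

    import torch

    def load_checkpoint(sharded_module, path: str = "checkpoint.pt") -> None:
        # Called on all ranks; each rank loads the full state_dict and the
        # wrapper re-shards it so the rank keeps only its own portion.
        full_state = torch.load(path, map_location="cpu")
        sharded_module.load_state_dict(full_state)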
@@ -550,8 +569,9 @@ class FullyShardedDataParallel(nn.Module):
     def _set_is_root(self) -> None:
         """If ``True``, implies that no other :class:`FullyShardedDataParallel`
         instance wraps this one. Called once by :func:`_lazy_init`.
-        Also sets self.children_share_process_group = True if all child instances share the same process group.
-        If some child instances use a different process group, self.clip_grad_norm_ will raise an error.
+        Also sets self.children_share_process_group = True if all child
+        instances share the same process group. If some child instances use a
+        different process group, self.clip_grad_norm_ will raise an error.
         """
         if self._is_root is not None:
             return