Commit 6b2897ca (unverified)
Authored Feb 25, 2021 by Myle Ott; committed by GitHub on Feb 25, 2021
[cleanup] FSDP docstrings (#428)
Parent: 2478a9ad

Showing 2 changed files with 69 additions and 45 deletions
docs/source/conf.py: +5, -1
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: +64, -44
docs/source/conf.py
@@ -37,7 +37,11 @@ release = "0.3.0"
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosectionlabel"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosectionlabel",
+    "sphinx.ext.napoleon",  # support NumPy and Google style docstrings
+]

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ["_templates"]
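For illustration only (not part of the commit): sphinx.ext.napoleon lets autodoc parse Google- and NumPy-style sections such as the Args:/Returns: blocks used in the FSDP docstrings below. A minimal sketch with a hypothetical helper function showing the docstring shape napoleon renders:

    import torch

    def scale_gradients(grads, factor=1.0):
        """Scale a list of gradient tensors in-place.

        Args:
            grads (List[torch.Tensor]): gradient tensors to scale.
            factor (float, Optional): multiplicative scaling factor. Default: 1.0.

        Returns:
            The same list of tensors, scaled by ``factor``.
        """
        for g in grads:
            g.mul_(factor)
        return grads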
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -63,6 +63,7 @@ class FullyShardedDataParallel(nn.Module):
     Usage::

+        torch.cuda.set_device(device_id)
         sharded_module = FullyShardedDataParallel(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
         x = sharded_module(x, y=3, z=torch.Tensor([1]))
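A slightly fuller sketch of the usage above, assuming torch.distributed has already been initialized with one process per GPU; device_id is a placeholder for the local rank, as in the docstring:

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel

    # Sketch only: assumes torch.distributed.init_process_group(...) has been
    # called, with one process per GPU and device_id set to the local rank.
    torch.cuda.set_device(device_id)
    sharded_module = FullyShardedDataParallel(nn.Linear(5, 3).cuda())

    # The optimizer is created *after* wrapping (see the warning further down).
    optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

    x = torch.randn(8, 5, device="cuda")
    out = sharded_module(x)
    out.sum().backward()
    optim.step()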
@@ -72,11 +73,12 @@ class FullyShardedDataParallel(nn.Module):
     It is also possible to shard individual layers separately and have an outer
     wrapper handle any leftover parameters. This can be helpful to further
-    reduce memory usage and to improve training speed by distributing the
-    unsharding (all-gather) across the forward pass. For example::
+    reduce GPU memory usage, reduce system memory usage when initializing large
+    models and to improve training speed by overlapping the all-gather step
+    across the forward pass. For example::

         sharded_model = FullyShardedDataParallel(
             nn.Sequential(  # doesn't have to be nn.Sequential
                 nn.Linear(5, 100),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
                 FullyShardedDataParallel(nn.Linear(100, 100)),
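The nested-wrapped sharded_model above is then used like any ordinary module; a brief sketch continuing the docstring example, with the optimizer again created after wrapping:

    # Each inner FullyShardedDataParallel gathers only its own parameter shard
    # around its own forward/backward, while the outer wrapper handles the
    # remaining unwrapped nn.Linear layers.
    optim = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)
    out = sharded_model(torch.randn(4, 5, device="cuda"))
    out.sum().backward()
    optim.step()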
@@ -84,46 +86,49 @@ class FullyShardedDataParallel(nn.Module):
             )
         )

+    .. warning::
+
+        The optimizer must be initialized *after* the module has been wrapped,
+        since FSDP will shard parameters in-place and this will break any
+        previously initialized optimizers.
+
     Args:
         module (nn.Module):
             module to checkpoint
         process_group (Optional):
             process group for sharding
         reshard_after_forward (bool, Optional):
             if ``True``, reshard parameters after the forward pass. This saves
             memory but slows training. This is only relevant when resharding
             individual layers.
         mixed_precision (bool, Optional):
             if ``True``, inputs, activations and gradients will be kept in FP16;
             computation and communication will occur in FP16; and a (sharded)
             master copy of the model weights will be maintained in FP32.
         fp32_reduce_scatter (bool, Optional):
             if ``True``, then reduce-scatter gradients in FP32. This is only
             relevant when *``mixed_precision``* is ``True``.
         flatten_parameters (bool, Optional):
             if ``True``, flatten parameters into a single contiguous tensor,
             which improves training speed.
         cpu_offload (bool, Optional):
             if ``True``, offload FP32 params to CPU. This is only relevant when
             *``mixed_precision``* is ``True``.
         compute_dtype (torch.dtype, Optional):
             dtype for full parameters for computation. This defaults to
             ``torch.float32`` unless *``mixed_precision``* is set, in which case
             it defaults to ``torch.float16``.
         move_grads_to_cpu (bool, Optional):
             move gradient shard to CPU after reduction. This is useful when
             combined with CPU-based optimizers. It defaults to the value of
             *``cpu_offload``*.
         bucket_cap_mb (int, Optional):
             FSDP will bucket parameters so that gradient reduction can
             potentially overlap with backward computation. bucket_cap_mb
             controls the bucket size in MegaBytes (MB). Buckets are sub-divided
             based on world_size, so the max shard size is roughly
             ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
             Default: 25.
     """

     def __init__(
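A hedged sketch tying the new warning and a few of the documented arguments together (flag values are illustrative, not recommendations; my_module is a placeholder):

    # Construct FSDP with some of the arguments documented above, then create
    # the optimizer *after* wrapping: sharding happens in-place, so parameter
    # references captured by an earlier optimizer would become stale.
    sharded = FullyShardedDataParallel(
        my_module,                    # placeholder nn.Module already on GPU
        reshard_after_forward=True,   # save memory when sharding inner layers
        mixed_precision=True,         # FP16 compute/comm, FP32 master weights
        flatten_parameters=True,      # single contiguous tensor per shard
        move_grads_to_cpu=False,      # otherwise defaults to cpu_offload
    )
    optim = torch.optim.SGD(sharded.parameters(), lr=0.01)  # after wrapping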
@@ -207,22 +212,27 @@ class FullyShardedDataParallel(nn.Module):
         # filter_params_fn: Callable[[Any], Any] = None,
     ) -> torch.Tensor:
         """
-        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
-        concatenated into a single vector. Gradients are modified in-place.
+        Clip all gradients at this point in time. The norm is computed over all
+        gradients together, as if they were concatenated into a single vector.
+        Gradients are modified in-place.

-        Arguments:
+        Args:
             max_norm (float or int): max norm of the gradients
-            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+                for infinity norm.

         Returns:
             Total norm of the parameters (viewed as a single vector).

-        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
-            under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
-            in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
+        .. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
+            handles the partitioning and multiple devices per rank under the
+            hood. The default torch util is not applicable here, because each
+            rank only has a partial view of all the grads in the model, so
+            calling it in the OSS context would lead to different scaling being
+            applied per subset of model parameters.

-        .. warning: This needs to be called on all ranks, since synchronization primitives will be used
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
         assert self.training_state == TrainingState.IDLE
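Since the method must run on every rank and only on the root instance, a typical call site (sketch, continuing the earlier placeholder training step) looks like:

    # Sketch: called on every rank, on the root (outermost) FSDP instance,
    # after backward and before the optimizer step.
    loss = sharded_module(x).sum()
    loss.backward()
    total_norm = sharded_module.clip_grad_norm_(max_norm=1.0)
    optim.step()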
@@ -367,7 +377,11 @@ class FullyShardedDataParallel(nn.Module):
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
-        wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
+        wrapped Module without any sharding-specific logic. Returned tensors
+        will always be typed float32.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
         """
         torch.cuda.synchronize()
         self._lazy_init()
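Because synchronization primitives are used, every rank calls state_dict(); a sketch where only rank 0 persists the full, unsharded (float32) result, with a placeholder path:

    # All ranks must call state_dict() (it synchronizes), but typically only
    # one rank writes the consolidated checkpoint to disk.
    full_state = sharded_module.state_dict()
    if torch.distributed.get_rank() == 0:
        torch.save(full_state, "checkpoint.pt")  # placeholder path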
@@ -399,7 +413,12 @@ class FullyShardedDataParallel(nn.Module):
     def load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
-        """Load a whole (unsharded) state_dict."""
+        """
+        Load a whole (unsharded) state_dict.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives will be used.
+        """
         torch.cuda.synchronize()
         self._lazy_init()
         self._rebuild_full_params()
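Loading mirrors saving: every rank calls load_state_dict() with the whole (unsharded) state_dict; a sketch reusing the placeholder checkpoint path from above:

    # All ranks load the same unsharded state_dict; full parameters are rebuilt
    # for the load (note the _rebuild_full_params() call in the hunk above).
    state = torch.load("checkpoint.pt", map_location="cpu")
    sharded_module.load_state_dict(state)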
@@ -550,9 +569,10 @@ class FullyShardedDataParallel(nn.Module):
     def _set_is_root(self) -> None:
         """If ``True``, implies that no other :class:`FullyShardedDataParallel`
         instance wraps this one. Called once by :func:`_lazy_init`.
-        Also sets self.children_share_process_group = True if all child instances share the same process group.
-        If some child instances use a different process group, self.clip_grad_norm_ will raise an error.
+        Also sets self.children_share_process_group = True if all child
+        instances share the same process group. If some child instances use a
+        different process group, self.clip_grad_norm_ will raise an error.
         """
         if self._is_root is not None:
             return
         # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False