Unverified Commit 67bf5bf8 authored by foreveronehundred, committed by GitHub

[FSDP] Add an arg for FSDP __init__ (#926)

* [FSDP] Add an arg for FSDP __init__

Add an arg, disable_reshard_on_root, for FSDP __init__ to handle the following issue
https://github.com/facebookresearch/fairscale/issues/878


In some cases (e.g. models wrapped by autowrap), the parameters of root modules need to be resharded after the forward pass, and reshard_after_forward should not be forced to False.
"disable_reshard_on_root" lets users choose whether reshard_after_forward of root modules is forced to False; a usage sketch appears below, before the diff.

* Update fully_sharded_data_parallel.py

Modified the description of the feature to explain it more clearly.

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Update the comments for disable_reshard_on_root
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Modified the comments

Modified the comments of disable_reshard_on_root
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 7202115e
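To illustrate the intended use, here is a minimal, single-process sketch of the memory-saving case described in the docstring below: an FSDP root module that is only a submodule of a larger model, constructed with ``disable_reshard_on_root=False``. It assumes one GPU, the NCCL backend, and a fairscale version that includes this change; ``Model``, ``encoder``, and ``head`` are illustrative names, not part of the patch.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process process group so FSDP can be constructed (illustrative setup).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # The wrapped encoder is an FSDP *root* that is only a submodule of the
        # overall model, so its backward pass does not start immediately after
        # its own forward pass. Passing disable_reshard_on_root=False keeps
        # reshard_after_forward in effect for it and saves memory in between.
        self.encoder = FSDP(
            nn.Linear(16, 16).cuda(),
            reshard_after_forward=True,
            disable_reshard_on_root=False,
        )
        self.head = nn.Linear(16, 4).cuda()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.encoder(x))


model = Model()
loss = model(torch.randn(2, 16, device="cuda")).sum()
loss.backward()
```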
@@ -204,6 +204,16 @@ class FullyShardedDataParallel(nn.Module):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
individual layers.
disable_reshard_on_root (bool, Optional):
If ``True``, ``reshard_after_forward`` will be forced to ``False`` when the module is an
FSDP root module, which improves performance. In that case we do not reshard the full
parameters of the FSDP root module, since those parameters are needed immediately for the
backward pass.
If ``False``, performance will be lower, but memory is saved. Consider the case where an
FSDP root module is a submodule of a larger model: its backward pass may not start
immediately after its forward pass finishes, so resharding the parameters of the FSDP
root module can save memory in the meantime.
Default: True.
mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16;
computation and communication will occur in FP16; and a (sharded)
@@ -303,6 +313,7 @@ class FullyShardedDataParallel(nn.Module):
# The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
reshard_after_forward: bool = True,
disable_reshard_on_root: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
@@ -365,6 +376,7 @@ class FullyShardedDataParallel(nn.Module):
"parameter uses all the available ranks for the optimized performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
@@ -1150,6 +1162,7 @@ class FullyShardedDataParallel(nn.Module):
# applies recursively, we only call this from the root instance.
self._cast_buffers()
if self.disable_reshard_on_root:
# Don't free the full params for the outer-most (root) instance, # Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the # since those params will be needed immediately after for the
# backward pass. # backward pass.
...
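For reference, the hunk above simply gates the pre-existing root-instance behavior on the new flag. A standalone sketch of that decision follows; the function name is illustrative and not part of the patch.

```python
def effective_reshard_after_forward(
    is_root: bool, reshard_after_forward: bool, disable_reshard_on_root: bool
) -> bool:
    """Illustrative mirror of the branch added in this commit."""
    if is_root and disable_reshard_on_root:
        # Don't free the full params for the outer-most (root) instance,
        # since those params are needed immediately for the backward pass.
        return False
    return reshard_after_forward


# Default: the root keeps its full params after forward (no resharding).
assert effective_reshard_after_forward(True, True, True) is False
# New opt-out: the root reshards after forward as requested, saving memory.
assert effective_reshard_after_forward(True, True, False) is True
# Non-root (inner) instances are unaffected by the flag.
assert effective_reshard_after_forward(False, True, True) is True
```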