Unverified Commit 7c2c3e00 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[fix] [FSDP] Do not lose original reshard_after_forward (#880)

* [fix] [FSDP] Do not lose original reshard_after_forward

- In a corner case we can lose this value
- Saving it and using it in the reset function fixes this
- This is a trivial case, probably not worth a dedicated test for now

* added changelog
parent 1eccb92d
......@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
### Changed
- FSDP: Fixed a corner case where the init order could cause the original `reshard_after_forward` flag to be lost [#880]
## [0.4.3] - 2021-11-18
......
......@@ -307,7 +307,7 @@ class FullyShardedDataParallel(nn.Module):
self.process_group = process_group or get_process_group_cached()
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
self.reshard_after_forward = reshard_after_forward
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
......@@ -1091,6 +1091,7 @@ class FullyShardedDataParallel(nn.Module):
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes
self._output_pre_backward_hook_registered: Optional[List] = None
self.reshard_after_forward = self._orig_reshard_after_forward
def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment