FSDP: Fix saving and loading checkpoints with use_sharded_state=True (#574)
* fix saving and loading checkpoints with use_sharded_state=True * mypy fix * better fix of the infinite recursion - we need to specifically call FSDP.state_dict from its local state_dict - added unit test that fails without the fix and works with the fix - fixed mypy for the overloaded functions * make cpu-only fsdp work for state_dict at least Co-authored-by:Min Xu <min.xu@acm.org> Co-authored-by:
Min Xu <min.xu.public@gmail.com> Co-authored-by:
Min Xu <m1n@fb.com>
Showing
Please register or sign in to comment