[fix] better handling non-flatten in FSDP (#1072)
* [fix] better handling non-flatten in FSDP
- see the detailed comment about that backward firing case
- also minor debugging help in FSDP
- also minor fix in FPW's state dict
* [feat] disallow reset_parameters by default
* [feat] adding fsdp_instances API - useful for checking wrapping from user code
* [fix] one line fix but more than a day of debugging
* fixed the case of loading combined check with empty fsdp instances
* fixed another state-loading bug with root/non-root module full-param caching, caused by not resharding after forward
* [feat] support .half and .float better
* fixed a bug where gathering optim state loses extra keys from the original state_dict
* fixed a test failure in mixed precision
* fixed another bug affecting no_sync grad acc
* fixed a bug and a test in fsdp optim state
* fixed another corner case
* added a comment
* skip ssd offload tests
* skip the fsdp one for ssd offload
Co-authored-by: Min Xu <min.xu.public@gmail.com>