[fix] add clear_autocast_cache flag (#650)
* [fix] add clear_autocast_cache flag
- when training in AMP mode with float32 weights, FSDP may need to
  clear the autocast cache to avoid GPU OOM
- the flag defaults to False; clearing the cache automatically is a future TODO
- also added a verbose flag to make print(fsdp_model) a bit shorter
- updated the memory test to cover the new code
- added a couple of useful functions in parallel.py and testing.py
* minor
* address comments
* format
* improve the test
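
For context, a minimal sketch of what an opt-in cache clear could look like under autocast. This is not the FSDP implementation: `forward_with_cache_clear` is a hypothetical helper written for illustration, and only `torch.autocast` and `torch.clear_autocast_cache` are real PyTorch calls. It assumes a PyTorch version that provides `torch.autocast`.

```python
import torch
import torch.nn as nn


def forward_with_cache_clear(module, x, clear_autocast_cache=False):
    """Hypothetical helper: run a forward pass under autocast and
    optionally drop the cached low-precision weight copies afterwards."""
    device_type = "cuda" if x.is_cuda else "cpu"
    # float16 on GPU, bfloat16 on CPU (assumed defaults for this sketch)
    dtype = torch.float16 if x.is_cuda else torch.bfloat16
    with torch.autocast(device_type=device_type, dtype=dtype):
        out = module(x)
        if clear_autocast_cache:
            # Release the casted copies of float32 weights that autocast
            # caches while the autocast region is active.
            torch.clear_autocast_cache()
    return out


model = nn.Linear(8, 4)  # weights stay float32
y = forward_with_cache_clear(model, torch.randn(2, 8), clear_autocast_cache=True)
```

Keeping the clear opt-in mirrors the default-False flag described above: dropping the cache trades some recasting work for lower peak memory, so it is only worth enabling when memory is the bottleneck.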
Co-authored-by: Min Xu <min.xu@acm.org>