    [fix] add clear_autocast_cache flag (#650) · 861b5ce2
    Min Xu authored
    
    
    * [fix] add clear_autocast_cache flag
    
    - when training in AMP mode with FP32 weights, FSDP may need to
      optionally clear the autocast cache to avoid GPU OOM
    - this flag defaults to false; clearing the cache automatically is a future TODO
    - also added a verbose flag to make print(fsdp_model) a bit shorter
    - updated the memory test to cover the new code
    - added a couple of useful functions in parallel.py and testing.py
    
    * minor
    
    * address comments
    
    * format
    
    * improve the test
    Co-authored-by: Min Xu <min.xu@acm.org>
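
The idea behind the flag described in the commit message can be illustrated with plain PyTorch. This is a minimal sketch, not the FSDP implementation: the helper `forward_clearing_cache` and its `clear_cache` argument are hypothetical names, and the example runs on CPU with bfloat16 only so that it works without a GPU, whereas the commit is concerned with GPU memory. It assumes a PyTorch version that provides `torch.autocast` for CPU.

```python
import torch
import torch.nn as nn


def forward_clearing_cache(layers, x, clear_cache=False):
    """Run layers under autocast, optionally clearing the autocast cache after each one.

    Hypothetical helper for illustration; it is not part of FSDP.
    """
    # Assumption: CPU + bfloat16 keeps the sketch runnable without a GPU.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        for layer in layers:
            x = layer(x)
            if clear_cache:
                # Drop any low-precision weight copies autocast has cached so far,
                # instead of keeping them until the autocast context exits.
                torch.clear_autocast_cache()
    return x


torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
out = forward_clearing_cache(layers, torch.randn(4, 8), clear_cache=True)
print(out.dtype, out.shape)
```

The parameters stay in float32 throughout; only the cached low-precision copies are released. In FSDP itself this behavior is requested through the `clear_autocast_cache` flag named in the commit title, which defaults to false.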
test_fsdp_memory.py 14.5 KB