    [fix] FSDP intra-backwards gradient accumulation. (#784) · 4fa2ab9b
    Darryl Barnhart authored
    * [fix] FSDP intra-backwards gradient accumulation.
    
    Ensure gradient reduction accumulates into the unsharded gradient tensor
    within a single backwards pass. This matters when an FSDP module is called
    multiple times within a forward pass and reduction is _not_ deferred via
    activation-checkpoint forward counters, bucketing, or some other
    mechanism.
    
    Closes #780
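
    A minimal sketch of the guaranteed behaviour in plain PyTorch, outside of
    fairscale's actual internals (`post_reduction_hook` and `_saved_grad` are
    assumed names for illustration, not the real API):

    ```python
    import torch
    import torch.nn as nn

    def post_reduction_hook(param: nn.Parameter, reduced_grad: torch.Tensor) -> None:
        # Fires once per backward trigger with the reduced gradient. If the
        # wrapped module ran more than once in the forward pass, the hook
        # fires more than once within the same backwards pass and must
        # accumulate into the existing unsharded gradient, not overwrite it.
        if getattr(param, "_saved_grad", None) is None:
            param._saved_grad = reduced_grad.clone()
        else:
            param._saved_grad += reduced_grad

    shared = nn.Linear(4, 4)            # stand-in for an FSDP-wrapped submodule
    x = torch.randn(2, 4)
    shared(shared(x)).sum().backward()  # same module called twice in one forward

    for p in shared.parameters():
        post_reduction_hook(p, p.grad)  # first reduction in the backwards pass
        post_reduction_hook(p, p.grad)  # second reduction must accumulate
        assert torch.allclose(p._saved_grad, 2 * p.grad)
    ```

    Autograd already accumulates into `.grad` across the two calls; the fix
    makes the reduction path treat the unsharded gradient buffer the same way.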
    
    * [refactor] Remove forward counters; expand comments.
    
    Remove forward counters from the activation checkpointing utility, now
    that FSDP does not require them for correct operation. Add a more
    detailed comment about memory-usage behaviour with gradient reduction.
    
    * [refactor] Delete deprecated forward counter usage.
    
    * [refactor] Add state assertion at the end of the pre-backward hook.
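
    For illustration, a toy sketch of what such an end-of-hook assertion can
    look like; the `TrainingState` enum and `training_state` attribute mirror
    the style of FSDP's internal state machine, but the exact names and
    transitions here are assumptions:

    ```python
    import enum

    class TrainingState(enum.Enum):
        IDLE = enum.auto()
        BACKWARD_PRE = enum.auto()

    class FSDPLike:
        """Toy stand-in for the FSDP wrapper's backward state machine."""

        def __init__(self) -> None:
            self.training_state = TrainingState.IDLE

        def _pre_backward_hook(self) -> None:
            # The hook can fire once per use of the module in the forward
            # pass; only the first firing transitions the state.
            if self.training_state is not TrainingState.BACKWARD_PRE:
                self.training_state = TrainingState.BACKWARD_PRE
                # ... all-gather the full parameters here ...
            # The assertion added at the end of the hook: fail loudly if the
            # state machine ever reaches this point in an unexpected state.
            assert self.training_state is TrainingState.BACKWARD_PRE
    ```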