    [Flax] Add remat (gradient checkpointing) (#17843) · 485bbe79
    Sanchit Gandhi authored
    * [Flax] Add remat (gradient checkpointing)
    
    * fix variable naming in test
    
    * flip: checkpoint using a method
    
    * fix naming
    
    * fix class naming
    
    * apply PVP's suggestions from code review
    
    * make fix-copies
    
    * fix big-bird, electra, roberta
    
    * cookie-cutter
    
    * fix flax big-bird
    
    * move test to common
test_modeling_flax_common.py 52.9 KB