    Flax Remat for LongT5 (#17994) · d6eeb871
    Karim Foda authored
    
    
    * [Flax] Add remat (gradient checkpointing)
    
    * fix variable naming in test
    
    * flip: checkpoint using a method
    
    * fix naming
    
    * fix class naming
    
    * apply PVP's suggestions from code review
    
    * add gradient_checkpointing to examples
    
    * Add gradient_checkpointing to run_mlm_flax
    
    * Add remat to longt5
    
    * Add gradient checkpointing test longt5
    
    * Fix args errors
    
    * Fix remaining tests
    
    * Make fixup & quality fixes
    
    * replace kwargs
    
    * remove unnecessary kwargs
    
    * Make fixup changes
    
    * revert long_t5_flax changes
    
    * Remove return_dict and copy to LongT5
    
    * Remove test_gradient_checkpointing
    Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
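Gradient checkpointing ("remat") trades compute for memory. Activations inside a checkpointed function are not kept for the backward pass; they are recomputed from that function's inputs when gradients are taken. Flax exposes this for modules as `flax.linen.remat`, a lifted version of `jax.checkpoint`. The minimal sketch below calls `jax.checkpoint` directly so it depends only on `jax`. The layer, the `use_remat` flag, and all function names are hypothetical illustrations, not the code this commit adds to LongT5.

```python
import jax
import jax.numpy as jnp


def dense_tanh(w, x):
    # Hypothetical layer: a matrix multiply followed by tanh.
    return jnp.tanh(x @ w)


def loss(weights, x, use_remat):
    # With use_remat=True, each layer is wrapped in jax.checkpoint, so its
    # intermediate activations are recomputed during the backward pass
    # instead of being stored from the forward pass.
    layer = jax.checkpoint(dense_tanh) if use_remat else dense_tanh
    for w in weights:
        x = layer(w, x)
    return jnp.sum(x ** 2)


def weight_grads(weights, x, use_remat):
    # Differentiate only with respect to the weights; x and the flag are closed over.
    return jax.grad(lambda ws: loss(ws, x, use_remat))(weights)


key_w1, key_w2, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
weights = [0.1 * jax.random.normal(k, (8, 8)) for k in (key_w1, key_w2)]
x = jax.random.normal(key_x, (4, 8))

g_plain = weight_grads(weights, x, use_remat=False)
g_remat = weight_grads(weights, x, use_remat=True)

# Checkpointing changes what is stored, not what is computed, so the
# gradients from both paths should agree.
for a, b in zip(g_plain, g_remat):
    print(bool(jnp.allclose(a, b, atol=1e-6)))
```

Wrapping each layer separately, rather than the whole model, limits the activations held in memory to roughly the layer boundaries at the cost of one extra forward pass per layer during backpropagation. This sketch does not claim which granularity the LongT5 change uses.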