    [UNet2DConditionModel] add gradient checkpointing (#461) · e7120bae
    Suraj Patil authored
    * add gradient checkpointing to downsample blocks
    
    * make it work
    
    * don't pass gradient_checkpointing to upsample block
    
    * add tests for UNet2DConditionModel
    
    * add test_gradient_checkpointing
    
    * add gradient_checkpointing for up and down blocks
    
    * add functions to enable and disable gradient checkpointing
    
    * remove the forward argument
    
    * better naming
    
    * make supports_gradient_checkpointing private
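    The changes listed above come down to three pieces: a per-block `gradient_checkpointing` flag checked in `forward`, model-level helpers that flip that flag (replacing a `forward` argument), and a private class attribute saying whether the model supports checkpointing at all. The sketch below illustrates that pattern with `torch.utils.checkpoint`. It is an assumption-laden toy, not the diffusers code from this commit: `ToyDownBlock`, `ToyUNet`, and `_set_gradient_checkpointing` are hypothetical names, and the layer layout is invented for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyDownBlock(nn.Module):
    """Hypothetical down block; NOT the actual diffusers block implementation."""

    def __init__(self, channels: int, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1), nn.SiLU())
            for _ in range(num_layers)
        )
        self.downsample = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        # Toggled by the model-level helpers, not passed through forward().
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Drop this layer's intermediate activations and recompute
                # them during backward, trading compute for memory.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return self.downsample(x)


class ToyUNet(nn.Module):
    # Private class attribute (hypothetical stand-in for the commit's
    # "make supports_gradient_checkpointing private").
    _supports_gradient_checkpointing = True

    def __init__(self, channels: int = 4):
        super().__init__()
        self.down_blocks = nn.ModuleList([ToyDownBlock(channels), ToyDownBlock(channels)])

    def _set_gradient_checkpointing(self, value: bool) -> None:
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{type(self).__name__} does not support gradient checkpointing.")
        # Set the flag on every submodule that exposes it.
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self) -> None:
        self._set_gradient_checkpointing(True)

    def disable_gradient_checkpointing(self) -> None:
        self._set_gradient_checkpointing(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.down_blocks:
            x = block(x)
        return x
```

    Usage: call `model.train()` and then `model.enable_gradient_checkpointing()` before the training loop; outputs and gradients should match the non-checkpointed run, only peak activation memory changes. Checking `self.training` keeps inference untouched, and `use_reentrant=False` lets gradients reach the layer parameters even when the block input does not itself require grad.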