    Added support to create asymmetrical U-Net structures (#5400) · 8dba1808
    Vishnu V Jaddipal authored
    
    
    * Added args, kwargs to ```U
    
    * Add UNetMidBlock2D as a supported mid block type
    
    * Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_blocks.py
    
    * Update unet_2d_blocks.py
    
    * Update unet_2d_blocks.py
    
    * Update unet_2d_condition.py
    
    * Update unet_2d_blocks.py
    
    * Updated docstring, increased check strictness
    
    Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```
    
    * Add basic shape-check test for asymmetrical U-Nets
    
    * Update src/diffusers/models/unet_2d_blocks.py
    
    Removed blank line
    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
    
    * Update unet_2d_condition.py
    
    Remove blank space
    
    * Update unet_2d_condition.py
    
    Changed docstring for `mid_block_type`
    
    * Fixed docstring and wrong default value
    
    * Reformat with black
    
    * Reformat with necessary commands
    
    * Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency
    
    * Removed args and kwargs usage on mid-block type
    
    * Make fix-copies
    
    * Update src/diffusers/models/unet_2d_condition.py
    
    Wrap into single line
    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
    
    * make fix-copies
    
    ---------
    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
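The mid-block changes in this commit (adding `UNetMidBlock2D` as a supported mid block type and fixing an extra init input for it) can be illustrated with a small, self-contained sketch. It is not the diffusers API: `SUPPORTED_MID_BLOCKS` and `mid_block_init_kwargs` are hypothetical names. The set of allowed types is also an assumption, and so is the use of `cross_attention_dim` as the example of a cross-attention-only argument, since the commit does not say which extra input was removed. The sketch shows `mid_block_type` accepting a plain `UNetMidBlock2D` alongside the cross-attention variants, or `None` for no mid block.

```python
from typing import Optional

# Hypothetical, assumed table: not the diffusers API. Maps each allowed
# mid_block_type to whether that block uses cross-attention.
SUPPORTED_MID_BLOCKS = {
    "UNetMidBlock2DCrossAttn": True,
    "UNetMidBlock2DSimpleCrossAttn": True,
    "UNetMidBlock2D": False,  # plain mid block without cross-attention
}


def mid_block_init_kwargs(mid_block_type: Optional[str], cross_attention_dim: int) -> Optional[dict]:
    """Hypothetical helper: return illustrative init kwargs for the mid block,
    or None when no mid block should be built."""
    if mid_block_type is None:
        return None
    if mid_block_type not in SUPPORTED_MID_BLOCKS:
        raise ValueError(f"unknown mid_block_type: {mid_block_type}")
    kwargs = {"block_type": mid_block_type}
    # Assumption for illustration: cross-attention-only arguments are passed
    # solely to cross-attention mid blocks, never to a plain UNetMidBlock2D.
    if SUPPORTED_MID_BLOCKS[mid_block_type]:
        kwargs["cross_attention_dim"] = cross_attention_dim
    return kwargs


print(mid_block_init_kwargs("UNetMidBlock2D", 768))
```

Keeping the argument filtering in one place means a plain mid block never receives inputs it does not accept, which is the kind of mismatch the "extra init input" fix addresses.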
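The `reverse_transformer_layers_per_block` argument documented in this commit can be understood through a minimal sketch. It assumes that a nested `transformer_layers_per_block` gives one transformer depth per layer inside each down block. As we understand diffusers, each up block has one more layer than its down counterpart (`layers_per_block + 1`). Under that assumption, per-layer depths cannot simply be mirrored, so an explicit up-path list is required. `resolve_transformer_layers` is a hypothetical helper, not the code in `unet_2d_condition.py`.

```python
from typing import List, Optional, Sequence, Tuple, Union

# An entry is one transformer depth for a whole block, or one depth per layer.
Entry = Union[int, Sequence[int]]


def resolve_transformer_layers(
    transformer_layers_per_block: Union[int, Sequence[Entry]],
    num_blocks: int,
    layers_per_block: int,
    reverse_transformer_layers_per_block: Optional[Sequence[Entry]] = None,
) -> Tuple[List[Entry], List[Entry]]:
    """Hypothetical sketch: return per-block transformer depths for the
    down path and the up path of an (optionally asymmetrical) U-Net."""
    if isinstance(transformer_layers_per_block, int):
        down = [transformer_layers_per_block] * num_blocks
    else:
        down = list(transformer_layers_per_block)
    if len(down) != num_blocks:
        raise ValueError("need one transformer_layers_per_block entry per down block")

    nested = any(not isinstance(entry, int) for entry in down)
    if reverse_transformer_layers_per_block is None:
        if nested:
            # Per-layer lists cannot be mirrored: up blocks have one extra layer.
            raise ValueError("nested transformer_layers_per_block requires reverse_transformer_layers_per_block")
        up = list(reversed(down))  # symmetrical case: mirror the down path
    else:
        up = list(reverse_transformer_layers_per_block)
    if len(up) != num_blocks:
        raise ValueError("need one reverse_transformer_layers_per_block entry per up block")

    # Assumed layer counts: layers_per_block per down block, one more per up block.
    for entry in down:
        if not isinstance(entry, int) and len(entry) != layers_per_block:
            raise ValueError("each nested down entry needs layers_per_block depths")
    for entry in up:
        if not isinstance(entry, int) and len(entry) != layers_per_block + 1:
            raise ValueError("each nested up entry needs layers_per_block + 1 depths")
    return down, up


down, up = resolve_transformer_layers(
    [[1, 2], [2, 2], [10, 10]], num_blocks=3, layers_per_block=2,
    reverse_transformer_layers_per_block=[[10, 10, 10], [2, 2, 2], [1, 1, 1]],
)
print(down, up)
```

With a flat list such as `[1, 2, 10]`, the up path is just the reverse (`[10, 2, 1]`). Only the nested form makes the U-Net asymmetrical and forces the caller to spell out the up path.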
unet_2d_condition.py 60.1 KB