    Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863) · 4c8a05f1
    dg845 authored
    
    
    * Add attn_groups argument to UNet2DMidBlock2D to control the number of groups in the internal Attention block's GroupNorm.
    
    * Add docstring for attn_norm_num_groups in UNet2DModel.
    
    * Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32.
    
    * Add test for attn_norm_num_groups to UNet2DModelTests.
    
    * Fix expected slices for slow tests.
    
    * Also fix tolerances for slow tests.
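The sketch below illustrates why the test config above must set attn_norm_num_groups explicitly. It is a minimal pure-Python sketch that assumes the mid block resolves its attention GroupNorm with this rule: an explicit attn_groups always wins; otherwise it reuses resnet_groups only when resnet_time_scale_shift == "default" and falls back to None (no GroupNorm) for any other setting, such as "scale_shift". It also assumes that UNet2DModel forwards attn_norm_num_groups to the mid block as attn_groups. The helper resolve_attn_groups is hypothetical and is not part of diffusers.

```python
from typing import Optional


def resolve_attn_groups(
    resnet_groups: int,
    resnet_time_scale_shift: str,
    attn_groups: Optional[int] = None,
) -> Optional[int]:
    """Hypothetical helper: GroupNorm group count for the mid-block Attention.

    Assumed rule: an explicit attn_groups is returned unchanged; otherwise
    resnet_groups is reused only for resnet_time_scale_shift == "default",
    and None (meaning no GroupNorm in the attention block) is returned for
    every other time-embedding mode, e.g. "scale_shift".
    """
    if attn_groups is None:
        attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
    return attn_groups


# Consistency-models-style config: scale_shift time embedding, no override.
print(resolve_attn_groups(32, "scale_shift"))      # None: attention GroupNorm is dropped
# Same config with the new argument set explicitly (as in the updated test).
print(resolve_attn_groups(32, "scale_shift", 32))  # 32: attention GroupNorm restored
# Default time embedding: the attention GroupNorm reuses resnet_groups.
print(resolve_attn_groups(32, "default"))          # 32
```

With "scale_shift" and no override, the attention block would silently lose its GroupNorm. Exposing the value as a separate argument lets checkpoints that use that combination load with the normalization they were trained with.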
    
    ---------
    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
unet_2d.py 15.8 KB