Unverified Commit 4c8a05f1 authored by dg845, committed by GitHub

Fix Consistency Models UNet2DMidBlock2D Attention GroupNorm Bug (#4863)



* Add attn_groups argument to UNetMidBlock2D to control the internal Attention block's GroupNorm (a usage sketch follows the commit metadata below).

* Add docstring for attn_norm_num_groups in UNet2DModel.

* Since the test UNet config uses resnet_time_scale_shift == 'scale_shift', also set attn_norm_num_groups to 32.

* Add test for attn_norm_num_groups to UNet2DModelTests.

* Fix expected slices for slow tests.

* Also fix tolerances for slow tests.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5fd42e5d
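
The sketch below is not part of the commit; it illustrates, under a made-up minimal configuration, how the new attn_norm_num_groups argument is intended to interact with resnet_time_scale_shift. Only the two argument names introduced by this commit come from the diff; every other value is a hypothetical placeholder.

from diffusers import UNet2DModel

def make_unet(resnet_time_scale_shift, attn_norm_num_groups):
    # Hypothetical small config; any valid UNet2DModel config would work here.
    return UNet2DModel(
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        norm_num_groups=32,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attn_norm_num_groups=attn_norm_num_groups,
    )

# "default" + None: the attention GroupNorm falls back to norm_num_groups (unchanged behavior).
assert make_unet("default", None).mid_block.attentions[0].group_norm is not None
# "scale_shift" + None: no attention GroupNorm, exactly as before this fix.
assert make_unet("scale_shift", None).mid_block.attentions[0].group_norm is None
# "scale_shift" + 32: the GroupNorm is now created -- this is the consistency models fix.
assert make_unet("scale_shift", 32).mid_block.attentions[0].group_norm.num_groups == 32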
@@ -27,6 +27,7 @@ TEST_UNET_CONFIG = {
         "ResnetUpsampleBlock2D",
     ],
     "resnet_time_scale_shift": "scale_shift",
+    "attn_norm_num_groups": 32,
     "upsample_type": "resnet",
     "downsample_type": "resnet",
 }
@@ -52,6 +53,7 @@ IMAGENET_64_UNET_CONFIG = {
         "ResnetUpsampleBlock2D",
     ],
     "resnet_time_scale_shift": "scale_shift",
+    "attn_norm_num_groups": 32,
     "upsample_type": "resnet",
     "downsample_type": "resnet",
 }
@@ -74,6 +74,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
+            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
+            given number of groups. If left as `None`, the group norm layer will only be created if
+            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
         norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
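
For reference, the docstring above describes the same fallback that the UNetMidBlock2D change further down implements. A standalone restatement of that selection logic, purely illustrative and not library code, looks like this:

from typing import Optional

def resolve_attn_groups(
    attn_norm_num_groups: Optional[int],
    norm_num_groups: int,
    resnet_time_scale_shift: str,
) -> Optional[int]:
    # An explicit value always wins.
    if attn_norm_num_groups is not None:
        return attn_norm_num_groups
    # Otherwise keep the pre-existing behavior: a GroupNorm only for "default".
    return norm_num_groups if resnet_time_scale_shift == "default" else None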
@@ -107,6 +111,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         act_fn: str = "silu",
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
+        attn_norm_num_groups: Optional[int] = None,
         norm_eps: float = 1e-5,
         resnet_time_scale_shift: str = "default",
         add_attention: bool = True,
@@ -192,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             resnet_time_scale_shift=resnet_time_scale_shift,
             attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
             resnet_groups=norm_num_groups,
+            attn_groups=attn_norm_num_groups,
             add_attention=add_attention,
         )
@@ -485,6 +485,7 @@ class UNetMidBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",  # default, spatial
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        attn_groups: Optional[int] = None,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
         attention_head_dim=1,
@@ -494,6 +495,9 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         self.add_attention = add_attention
 
+        if attn_groups is None:
+            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -526,7 +530,7 @@ class UNetMidBlock2D(nn.Module):
                         dim_head=attention_head_dim,
                         rescale_output_factor=output_scale_factor,
                         eps=resnet_eps,
-                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+                        norm_num_groups=attn_groups,
                         spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                         residual_connection=True,
                         bias=True,
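
To see the effect at the block level, one could construct UNetMidBlock2D directly. Note that the import path below is an internal diffusers module and may move between releases, and the channel/temb sizes are arbitrary illustrative values:

from diffusers.models.unet_2d_blocks import UNetMidBlock2D

# Arbitrary sizes; in_channels must be divisible by resnet_groups (default 32).
block = UNetMidBlock2D(
    in_channels=64,
    temb_channels=128,
    resnet_time_scale_shift="scale_shift",
    attn_groups=32,
)
# With this fix, the Attention layer gets a GroupNorm even under "scale_shift".
print(block.attentions[0].group_norm)  # e.g. GroupNorm(32, 64, ...)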
@@ -74,6 +74,36 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def test_mid_block_attn_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["add_attention"] = True
+        init_dict["attn_norm_num_groups"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        self.assertIsNotNone(
+            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
+        )
+        self.assertEqual(
+            model.mid_block.attentions[0].group_norm.num_groups,
+            init_dict["attn_norm_num_groups"],
+            "Mid block Attention group norm does not have the expected number of groups.",
+        )
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
 
 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -216,9 +216,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
+        expected_slice = np.array([0.0146, 0.0158, 0.0092, 0.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0058])
 
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     def test_consistency_model_cd_onestep(self):
         unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
@@ -239,9 +239,9 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
+        expected_slice = np.array([0.0059, 0.0003, 0.0000, 0.0023, 0.0052, 0.0007, 0.0165, 0.0081, 0.0095])
 
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     @require_torch_2
     def test_consistency_model_cd_multistep_flash_attn(self):
@@ -263,7 +263,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
+        expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@@ -289,6 +289,6 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
+        expected_slice = np.array([0.1623, 0.2009, 0.2387, 0.1731, 0.1168, 0.1202, 0.2031, 0.1327, 0.2447])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3