Unverified Commit e55687e1 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

unet check length inputs (#2327)



* unet check length input

* prep test file for changes

* correct all tests

* clean up

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 9e8ee2ac
......@@ -94,7 +94,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
act_fn: str = "silu",
attention_head_dim: int = 8,
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
......@@ -107,6 +107,17 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
......
......@@ -150,6 +150,27 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
......
......@@ -236,6 +236,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
f" {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
f" {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
"Must provide the same number of `only_cross_attention` as `down_block_types`."
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
f" {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
......
......@@ -56,7 +56,7 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
# SD2-specific config below
attention_head_dim=(2, 4, 8, 8),
attention_head_dim=(2, 4),
use_linear_projection=True,
)
scheduler = DDIMScheduler(
......
......@@ -65,7 +65,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
attention_head_dim=(2, 4, 8, 8),
attention_head_dim=(2, 4),
use_linear_projection=True,
)
scheduler = PNDMScheduler(skip_prk_steps=True)
......@@ -284,7 +284,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps":
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
else:
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
......@@ -305,7 +305,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps":
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
else:
expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621])
expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
......@@ -327,7 +327,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps":
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
else:
expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681])
expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
......@@ -382,7 +382,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps":
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
else:
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
......
......@@ -47,7 +47,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
# SD2-specific config below
attention_head_dim=(2, 4, 8, 8),
attention_head_dim=(2, 4),
use_linear_projection=True,
)
scheduler = PNDMScheduler(skip_prk_steps=True)
......
......@@ -56,7 +56,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
# SD2-specific config below
attention_head_dim=(2, 4, 8, 8),
attention_head_dim=(2, 4),
use_linear_projection=True,
)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment