Unverified commit b13cdbb2 authored by hlky, committed by GitHub

UNet2DModel mid_block_type (#10469)

parent a0acbdc9
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
             Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        mid_block_type: Optional[str] = "UNetMidBlock2D",
         up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
         layers_per_block: int = 2,
@@ -194,6 +195,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)
 
         # mid
-        self.mid_block = UNetMidBlock2D(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
+        if mid_block_type is None:
+            self.mid_block = None
+        else:
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
@@ -322,6 +326,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             down_block_res_samples += res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, emb)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, emb)
 
         # 5. up
...
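
For context (not part of the commit): a minimal usage sketch of the new `mid_block_type` argument, assuming a diffusers build that includes this change. The block types and channel sizes below are illustrative choices, not values taken from the diff.

```python
import torch
from diffusers import UNet2DModel

# Illustrative configuration; block types and channel sizes are arbitrary for this sketch.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    mid_block_type=None,  # new option: skip building a mid block entirely
)
assert model.mid_block is None  # no UNetMidBlock2D is constructed

# The forward pass simply bypasses the absent mid block.
sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([1])
with torch.no_grad():
    out = model(sample, timestep).sample
print(out.shape)  # torch.Size([1, 3, 32, 32])
```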
@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
+    def test_mid_block_none(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict["mid_block_type"] = None
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        mid_none_model = self.model_class(**mid_none_init_dict)
+        mid_none_model.to(torch_device)
+        mid_none_model.eval()
+
+        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        with torch.no_grad():
+            mid_none_output = mid_none_model(**mid_none_inputs_dict)
+
+            if isinstance(mid_none_output, dict):
+                mid_none_output = mid_none_output.to_tuple()[0]
+
+        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
+
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "AttnUpBlock2D",
...