Unverified Commit b13cdbb2 authored by hlky's avatar hlky Committed by GitHub
Browse files

UNet2DModel mid_block_type (#10469)

parent a0acbdc9
...@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types. Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types. Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
...@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
freq_shift: int = 0, freq_shift: int = 0,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2, layers_per_block: int = 2,
...@@ -194,19 +195,22 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -194,19 +195,22 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
# mid # mid
self.mid_block = UNetMidBlock2D( if mid_block_type is None:
in_channels=block_out_channels[-1], self.mid_block = None
temb_channels=time_embed_dim, else:
dropout=dropout, self.mid_block = UNetMidBlock2D(
resnet_eps=norm_eps, in_channels=block_out_channels[-1],
resnet_act_fn=act_fn, temb_channels=time_embed_dim,
output_scale_factor=mid_block_scale_factor, dropout=dropout,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_eps=norm_eps,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], resnet_act_fn=act_fn,
resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor,
attn_groups=attn_norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=add_attention, attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
) resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up # up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
...@@ -322,7 +326,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -322,7 +326,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples down_block_res_samples += res_samples
# 4. mid # 4. mid
sample = self.mid_block(sample, emb) if self.mid_block is not None:
sample = self.mid_block(sample, emb)
# 5. up # 5. up
skip_sample = None skip_sample = None
......
...@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_mid_block_none(self):
    """Check that constructing the model with ``mid_block_type=None`` omits the
    mid block and measurably changes the forward-pass output.

    Builds two models from identical init args — one default, one with the mid
    block disabled — and asserts the disabled model has ``mid_block is None``
    and that the two outputs are not numerically close.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
    mid_none_init_dict["mid_block_type"] = None

    def _build(config):
        # Instantiate from config, move to the test device, switch to eval mode.
        net = self.model_class(**config)
        net.to(torch_device)
        net.eval()
        return net

    model = _build(init_dict)
    mid_none_model = _build(mid_none_init_dict)

    self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")

    def _forward(net, inputs):
        # No-grad forward pass; unwrap dict-style model outputs to the sample tensor.
        with torch.no_grad():
            result = net(**inputs)
        if isinstance(result, dict):
            result = result.to_tuple()[0]
        return result

    output = _forward(model, inputs_dict)
    mid_none_output = _forward(mid_none_model, mid_none_inputs_dict)

    self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = { expected_set = {
"AttnUpBlock2D", "AttnUpBlock2D",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment