Unverified Commit b13cdbb2 authored by hlky's avatar hlky Committed by GitHub
Browse files

UNet2DModel mid_block_type (#10469)

parent a0acbdc9
......@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
Block type for the middle of the UNet: either `UNetMidBlock2D` or `None`, in which case no mid block is built and the step is skipped in the forward pass.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
......@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
......@@ -194,19 +195,22 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
if mid_block_type is None:
self.mid_block = None
else:
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
......@@ -322,7 +326,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
if self.mid_block is not None:
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
......
......@@ -105,6 +105,35 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_mid_block_none(self):
    """Check that `mid_block_type=None` drops the mid block and changes the output.

    Two models are built from the common init args, one of them with
    `mid_block_type=None`. To make the final comparison meaningful, the
    mid-less model receives the reference model's weights for every shared
    parameter and both models are run on the *same* inputs, so the only
    remaining difference between the two forward passes is the mid block.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    # Only the init args of the second call are used; both models are fed `inputs_dict` below.
    mid_none_init_dict, _ = self.prepare_init_args_and_inputs_for_common()
    mid_none_init_dict["mid_block_type"] = None

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    mid_none_model = self.model_class(**mid_none_init_dict)
    mid_none_model.to(torch_device)
    mid_none_model.eval()

    self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")

    # Independently constructed models would differ through their initial weights alone
    # (assuming the usual random initialisation), which would make the inequality check
    # below pass vacuously. Share every common parameter; `strict=False` tolerates the
    # reference model's extra `mid_block.*` entries.
    incompatible_keys = mid_none_model.load_state_dict(model.state_dict(), strict=False)
    self.assertEqual(
        list(incompatible_keys.missing_keys), [], "Every parameter of the mid-less model should be shared."
    )
    self.assertTrue(
        all(key.startswith("mid_block.") for key in incompatible_keys.unexpected_keys),
        "Only mid block parameters should be absent from the mid-less model.",
    )

    with torch.no_grad():
        output = model(**inputs_dict)
        mid_none_output = mid_none_model(**inputs_dict)

    # Model outputs may be dict-like containers; unwrap to the first tensor in that case.
    if isinstance(output, dict):
        output = output.to_tuple()[0]
    if isinstance(mid_none_output, dict):
        mid_none_output = mid_none_output.to_tuple()[0]

    self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment