Commit 8db5e5b3 authored by William Berman, committed by Will Berman
Browse files

allow unet varying number of layers per block

parent 707341ae
......@@ -132,7 +132,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
......@@ -186,6 +186,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
......@@ -260,6 +265,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
......@@ -277,7 +285,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
......@@ -338,6 +346,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -358,7 +367,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
......
......@@ -218,7 +218,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
......@@ -277,6 +277,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
f" {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
......@@ -351,6 +357,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
......@@ -368,7 +377,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
......@@ -429,6 +438,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
......@@ -449,7 +459,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment