Unverified Commit 55e17907 authored by dg845, committed by GitHub

Add dropout parameter to UNet2DModel/UNet2DConditionModel (#4882)

* Add dropout param to get_down_block/get_up_block and UNet2DModel/UNet2DConditionModel.

* Add dropout param to Versatile Diffusion modeling, which has a copy of UNet2DConditionModel and its own get_down_block/get_up_block functions.
parent c81a88b2
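
For context on the change above, a minimal usage sketch (not part of the commit): `dropout` becomes a single constructor argument on both UNets and is forwarded through `get_down_block`/`get_up_block` to every down, mid, and up block. The small block configuration below is an illustrative choice, not a value taken from the diff.

```python
# Minimal sketch assuming a diffusers build that includes this commit.
from diffusers import UNet2DModel, UNet2DConditionModel

unet = UNet2DModel(
    sample_size=32,
    block_out_channels=(32, 64),                          # tiny config, illustration only
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    dropout=0.1,                                          # new argument added by this commit
)

# The conditional UNet accepts the same argument (defaults used for everything else).
cond_unet = UNet2DConditionModel(dropout=0.1)

# __init__ arguments are registered to the model config, so the value is also
# visible (and serialized) there.
print(unet.config.dropout, cond_unet.config.dropout)  # 0.1 0.1
```
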
@@ -70,6 +70,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
@@ -102,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
@@ -175,6 +177,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -182,6 +185,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -215,6 +219,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -55,6 +55,7 @@ def get_down_block(
cross_attention_norm=None,
attention_head_dim=None,
downsample_type=None,
dropout=0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -70,6 +71,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -83,6 +85,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -101,6 +104,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
@@ -118,6 +122,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -140,6 +145,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -158,6 +164,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -170,6 +177,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -181,6 +189,7 @@ def get_down_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -193,6 +202,7 @@ def get_down_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -207,6 +217,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -217,6 +228,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -252,6 +264,7 @@ def get_up_block(
cross_attention_norm=None,
attention_head_dim=None,
upsample_type=None,
dropout=0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -268,6 +281,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -281,6 +295,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -299,6 +314,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -321,6 +337,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -345,6 +362,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
@@ -359,6 +377,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -371,6 +390,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -382,6 +402,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -394,6 +415,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -408,6 +430,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -418,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
@@ -178,6 +179,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
@@ -459,6 +461,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -468,6 +471,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -484,6 +488,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -550,6 +555,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -58,6 +58,7 @@ def get_down_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
dropout=0.0,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat":
@@ -66,6 +67,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -81,6 +83,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -117,6 +120,7 @@ def get_up_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
dropout=0.0,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat":
@@ -126,6 +130,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -141,6 +146,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -284,6 +290,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
@@ -369,6 +376,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
@@ -660,6 +668,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -669,6 +678,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -685,6 +695,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -751,6 +762,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
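
As a quick sanity check (again, not part of the commit), a nonzero `dropout` only changes behavior in training mode, since the underlying `nn.Dropout` layers are no-ops under `eval()`. A hedged sketch using the same illustrative configuration as above:

```python
# Sketch: with dropout > 0, repeated forward passes differ in train() mode
# but stay deterministic in eval() mode.
import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    dropout=0.5,
)
sample = torch.randn(1, 3, 32, 32)

unet.train()
out_a = unet(sample, timestep=0).sample
out_b = unet(sample, timestep=0).sample
print(torch.allclose(out_a, out_b))  # expected False: dropout masks differ per pass

unet.eval()
with torch.no_grad():
    out_a = unet(sample, timestep=0).sample
    out_b = unet(sample, timestep=0).sample
print(torch.allclose(out_a, out_b))  # expected True: dropout disabled at inference
```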