Unverified commit 55e17907 authored by dg845, committed by GitHub

Add dropout parameter to UNet2DModel/UNet2DConditionModel (#4882)

* Add dropout param to get_down_block/get_up_block and UNet2DModel/UNet2DConditionModel.

* Add dropout param to Versatile Diffusion modeling, which has a copy of UNet2DConditionModel and its own get_down_block/get_up_block functions.
parent c81a88b2
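
For reference, the new argument is exposed as an ordinary config field on both models. The snippet below is a minimal usage sketch, not part of this diff; `sample_size=32`, the small `block_out_channels`, and `dropout=0.1` are arbitrary illustration values, and the default of `0.0` preserves the previous behavior.

```python
# Minimal sketch (assumes a diffusers build that includes this change).
from diffusers import UNet2DConditionModel, UNet2DModel

# Unconditional UNet: `dropout` is forwarded to every down, mid, and up block.
unet = UNet2DModel(sample_size=32, dropout=0.1)

# Conditional UNet: same keyword, threaded through get_down_block/get_up_block
# and the mid block. The small block_out_channels just keeps the example light.
cond_unet = UNet2DConditionModel(block_out_channels=(32, 64, 64, 64), dropout=0.1)

# Like other constructor arguments, the value is recorded in the model config.
assert unet.config.dropout == 0.1
assert cond_unet.config.dropout == 0.1
```
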
@@ -70,6 +70,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
@@ -102,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
+ dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
@@ -175,6 +177,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
+ dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -182,6 +185,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
+ dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -215,6 +219,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
+ dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
...
@@ -55,6 +55,7 @@ def get_down_block(
cross_attention_norm=None,
attention_head_dim=None,
downsample_type=None,
+ dropout=0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -70,6 +71,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -83,6 +85,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -101,6 +104,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
@@ -118,6 +122,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -140,6 +145,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -158,6 +164,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -170,6 +177,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -181,6 +189,7 @@ def get_down_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -193,6 +202,7 @@ def get_down_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -207,6 +217,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -217,6 +228,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -252,6 +264,7 @@ def get_up_block(
cross_attention_norm=None,
attention_head_dim=None,
upsample_type=None,
+ dropout=0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -268,6 +281,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -281,6 +295,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -299,6 +314,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -321,6 +337,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -345,6 +362,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
@@ -359,6 +377,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -371,6 +390,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -382,6 +402,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -394,6 +415,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -408,6 +430,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -418,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
...
@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
@@ -178,6 +179,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
@@ -459,6 +461,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -468,6 +471,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -484,6 +488,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -550,6 +555,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
...
@@ -58,6 +58,7 @@ def get_down_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
+ dropout=0.0,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlockFlat":
@@ -66,6 +67,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -81,6 +83,7 @@ def get_down_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ dropout=dropout,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -117,6 +120,7 @@ def get_up_block(
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
+ dropout=0.0,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlockFlat":
@@ -126,6 +130,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -141,6 +146,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -284,6 +290,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
@@ -369,6 +376,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
@@ -660,6 +668,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -669,6 +678,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -685,6 +695,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
@@ -751,6 +762,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
...