@@ -512,17 +512,24 @@ class UNetMidBlock2D(nn.Module):
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_time_scale_shift (`str`, *optional*, defaults to `default`): The type of normalization to apply to the time embeddings. This can help to improve the performance of the model on tasks with long-range temporal dependencies.
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
model on tasks with long-range temporal dependencies.
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`): Whether to use pre-normalization for the resnet blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
Whether to use pre-normalization for the resnet blocks.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels.
attention_head_dim (`int`, *optional*, defaults to 1):
Dimension of a single attention head. The number of attention heads is determined based on this value and
the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, height, width)`.
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,