Unverified commit 174dcd69 authored by Steven Liu, committed by GitHub

[docs] Model API (#3562)

* add modelmixin and unets

* remove old model page

* minor fixes

* fix unet2dcondition

* add vqmodel and autoencoderkl

* add rest of models

* fix autoencoderkl path

* fix toctree

* fix toctree again

* apply feedback

* apply feedback

* fix copies

* fix controlnet copy

* fix copies
parent cdf2ae8a
@@ -26,9 +26,11 @@ from .modeling_utils import ModelMixin
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
"""
The output of [`TransformerTemporalModel`].

Args:
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input.
"""
sample: torch.FloatTensor
@@ -36,24 +38,23 @@ class TransformerTemporalModelOutput(BaseOutput):
class TransformerTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.

Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
"""
@register_to_config
@@ -114,25 +115,27 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return_dict: bool = True,
):
"""
The [`TransformerTemporalModel`] forward method.

Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states (`torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep (`torch.long`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
plain tuple.

Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, a [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
...
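The hunks above only touch docstrings, so behavior is unchanged. As a hedged usage sketch (not part of this commit), the documented forward signature can be exercised roughly as follows; the config values and tensor shapes below are illustrative assumptions rather than library defaults.

```python
import torch
from diffusers.models.transformer_temporal import TransformerTemporalModel

# Illustrative config: in_channels must match the channel dimension of the input
# and stay divisible by the group-norm group count (32 by default).
model = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=32,
    in_channels=32,
    num_layers=1,
)

batch, num_frames = 1, 8
# Frames are flattened into the batch dimension: (batch * num_frames, channels, height, width).
hidden_states = torch.randn(batch * num_frames, 32, 16, 16)

with torch.no_grad():
    # return_dict=True (the default) returns a TransformerTemporalModelOutput ...
    sample = model(hidden_states, num_frames=num_frames).sample
    # ... while return_dict=False returns a plain tuple whose first element is the sample tensor.
    sample_tuple = model(hidden_states, num_frames=num_frames, return_dict=False)[0]

print(sample.shape)  # same shape as hidden_states
```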
@@ -28,9 +28,11 @@ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up
@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].

Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
@@ -38,10 +40,10 @@ class UNet1DOutput(BaseOutput):
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).

Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
@@ -49,24 +51,24 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for the middle of the UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of the UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""
@register_to_config
@@ -197,15 +199,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.

Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 1. time
...
@@ -27,9 +27,11 @@ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
The output of [`UNet2DModel`].

Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
@@ -37,46 +39,45 @@ class UNet2DOutput(BaseOutput):
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).

Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
1)`.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
@@ -224,17 +225,21 @@ class UNet2DModel(ModelMixin, ConfigMixin):
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
The [`UNet2DModel`] forward method.

Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
...
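As a usage sketch for the docstrings above (a sketch, not part of the diff): a single denoising step through an unconditional 2D UNet, assuming the public `google/ddpm-cat-256` checkpoint is available on the Hub.

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cat-256")

# Random noise matching the model's configured sample size and channel count.
sample = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
)
timestep = torch.tensor([10])

with torch.no_grad():
    output = model(sample, timestep)          # UNet2DOutput when return_dict=True (default)
    noise_pred = output.sample                # predicted noise residual, same shape as `sample`
    noise_pred_tuple = model(sample, timestep, return_dict=False)[0]
```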
@@ -50,9 +50,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
The output of [`UNet2DConditionModel`].

Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input, from the last layer of the model.
"""
sample: torch.FloatTensor
@@ -60,17 +62,17 @@ class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.

This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).

Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
@@ -78,9 +80,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for the middle of the UNet; it can be either `UNetMidBlock2DCrossAttn` or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
Whether to include self-attention in the basic transformer blocks, see
@@ -92,52 +94,52 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
encoder_hid_dim (`int`, *optional*, defaults to `None`):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest of
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of the `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Otherwise, it
defaults to `False`.
"""
_supports_gradient_checkpointing = True
@@ -551,11 +553,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.

Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.

If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
@@ -589,15 +595,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
r"""
Enable sliced attention computation.

When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.

Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
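A short sketch of the two APIs documented in the hunks above, `set_attn_processor` and attention slicing (a sketch, not part of the commit); the Stable Diffusion checkpoint is an assumption, and any `UNet2DConditionModel` weights with standard attention layers behave the same way.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Set a single processor instance for **all** Attention layers.
unet.set_attn_processor(AttnProcessor())

# The keys of this dict are the paths a per-layer dict argument would have to use.
print(list(unet.attn_processors)[:3])

# Trade a small amount of speed for memory by slicing the attention computation.
unet.set_attention_slice("auto")
```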
@@ -670,29 +676,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.

Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified includes additional conditions that can be used for additional
time embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type`
and `addition_embed_type` for more information.

Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
...
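And a minimal sketch of the forward signature documented above (editorial, not from the diff): one conditional denoising step, assuming the Stable Diffusion v1 UNet with 4 latent channels and `cross_attention_dim=768`.

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

latents = torch.randn(1, 4, 64, 64)          # (batch, channel, height, width)
timestep = torch.tensor([981])               # denoising timestep
text_embeddings = torch.randn(1, 77, 768)    # (batch, sequence_length, feature_dim)

with torch.no_grad():
    noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeddings).sample

print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```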
@@ -35,9 +35,11 @@ from .unet_2d_blocks_flax import (
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
The output of [`FlaxUNet2DConditionModel`].

Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input, from the last layer of the model.
"""
sample: jnp.ndarray
@@ -46,17 +48,17 @@ class FlaxUNet2DConditionOutput(BaseOutput):
@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.

This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
implemented for all models (such as downloading or saving).

This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
general usage and behavior.

Inherent JAX features such as the following are supported:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
@@ -69,12 +71,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
@@ -91,8 +91,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
"""
sample_size: int = 32
...
@@ -43,9 +43,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class UNet3DConditionOutput(BaseOutput):
"""
The output of [`UNet3DConditionModel`].

Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input, from the last layer of the model.
"""
sample: torch.FloatTensor
@@ -53,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput):
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.

This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).

Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
@@ -66,7 +68,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
@@ -75,7 +77,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
@@ -291,15 +293,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
r"""
Enable sliced attention computation.

When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.

Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
@@ -355,11 +357,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.

Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.

If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
@@ -408,21 +414,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
The [`UNet3DConditionModel`] forward method.

Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].

Returns:
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
...
@@ -30,7 +30,7 @@ class DecoderOutput(BaseOutput):
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
...
...@@ -36,9 +36,9 @@ class FlaxDecoderOutput(BaseOutput): ...@@ -36,9 +36,9 @@ class FlaxDecoderOutput(BaseOutput):
Args: Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model. The decoded output sample from the last layer of the model.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
Parameters `dtype` The `dtype` of the parameters.
""" """
sample: jnp.ndarray sample: jnp.ndarray
...@@ -720,40 +720,43 @@ class FlaxDiagonalGaussianDistribution(object): ...@@ -720,40 +720,43 @@ class FlaxDiagonalGaussianDistribution(object):
@flax_register_to_config @flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
r""" r"""
Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Flax implementation of a VAE model with KL loss for decoding latent representations.
Bayes by Diederik P. Kingma and Max Welling.
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
implemented for all models (such as downloading or saving).
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
general usage and behavior. general usage and behavior.
Finally, this model supports inherent JAX features such as: Inherent JAX features such as the following are supported:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters: Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3): in_channels (`int`, *optional*, defaults to 3):
Input channels Number of channels in the input image.
out_channels (:obj:`int`, *optional*, defaults to 3): out_channels (`int`, *optional*, defaults to 3):
Output channels Number of channels in the output.
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
DownEncoder block type Tuple of downsample block types.
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
UpDecoder block type Tuple of upsample block types.
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block Tuple of block output channels.
layers_per_block (:obj:`int`, *optional*, defaults to `2`): layers_per_block (`int`, *optional*, defaults to `2`):
Number of Resnet layer for each block Number of ResNet layers for each block.
act_fn (:obj:`str`, *optional*, defaults to `silu`): act_fn (`str`, *optional*, defaults to `silu`):
Activation function The activation function to use.
latent_channels (:obj:`int`, *optional*, defaults to `4`): latent_channels (`int`, *optional*, defaults to `4`):
Latent space channels Number of channels in the latent space.
norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm_num_groups (`int`, *optional*, defaults to `32`):
Norm num group The number of groups for normalization.
sample_size (:obj:`int`, *optional*, defaults to 32): sample_size (`int`, *optional*, defaults to 32):
Sample input size Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215): scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion training set. This is used to scale the latent space to have unit variance when training the diffusion
...@@ -761,8 +764,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -761,8 +764,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
parameters `dtype` The `dtype` of the parameters.
""" """
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
......
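The `scaling_factor` convention described in the docstring above can be exercised end to end. The sketch below is an assumption-laden example, not part of this diff: it randomly initialises a small `FlaxAutoencoderKL` rather than loading trained weights, and the image shape is illustrative only.

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

# Randomly initialised, tiny VAE; real use would load trained weights instead.
vae = FlaxAutoencoderKL(sample_size=32, latent_channels=4)
params = vae.init_weights(jax.random.PRNGKey(0))

rng = jax.random.PRNGKey(1)
images = jnp.zeros((1, 3, 32, 32))  # (batch, channels, height, width)

# Encode, then scale the latents toward unit variance with scaling_factor.
posterior = vae.apply({"params": params}, images, method=vae.encode)
latents = posterior.latent_dist.sample(rng) * vae.scaling_factor

# Decode: undo the scaling first, i.e. z = 1 / scaling_factor * z.
decoded = vae.apply(
    {"params": params}, latents / vae.scaling_factor, method=vae.decode
).sample
```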
...@@ -30,31 +30,31 @@ class VQEncoderOutput(BaseOutput): ...@@ -30,31 +30,31 @@ class VQEncoderOutput(BaseOutput):
Args: Args:
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Encoded output sample of the model. Output of the last layer of the model. The encoded output sample from the last layer of the model.
""" """
latents: torch.FloatTensor latents: torch.FloatTensor
class VQModel(ModelMixin, ConfigMixin): class VQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray r"""
Kavukcuoglu. A VQ-VAE model for decoding latent representations.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
implements for all the model (such as downloading or saving, etc.) for all models (such as downloading or saving).
Parameters: Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image. in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output. out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to : down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to : up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to : block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
obj:`(64,)`): Tuple of block output channels. Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO sample_size (`int`, *optional*, defaults to `32`): Sample input size.
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
scaling_factor (`float`, *optional*, defaults to `0.18215`): scaling_factor (`float`, *optional*, defaults to `0.18215`):
...@@ -143,10 +143,17 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -143,10 +143,17 @@ class VQModel(ModelMixin, ConfigMixin):
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r""" r"""
The [`VQModel`] forward method.
Args: Args:
sample (`torch.FloatTensor`): Input sample. sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple. Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If `return_dict` is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple`
is returned.
""" """
x = sample x = sample
h = self.encode(x).latents h = self.encode(x).latents
......
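Below is a minimal sketch of the encode/forward behaviour documented above (not part of this diff). It builds a `VQModel` from explicitly spelled-out default config values, so the shapes are illustrative only.

```python
import torch
from diffusers import VQModel

model = VQModel(
    in_channels=3,
    out_channels=3,
    latent_channels=3,
    num_vq_embeddings=256,
    block_out_channels=(64,),
    sample_size=32,
)

images = torch.randn(1, 3, 32, 32)

# encode() alone returns a VQEncoderOutput whose `latents` field is the encoder output.
latents = model.encode(images).latents

# forward() runs encode -> quantize -> decode.
reconstruction = model(images, return_dict=True).sample
reconstruction = model(images, return_dict=False)[0]  # plain tuple; first element is the sample
```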
...@@ -153,17 +153,17 @@ def get_up_block( ...@@ -153,17 +153,17 @@ def get_up_block(
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin): class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r""" r"""
UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
timestep and returns sample shaped output. shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
implements for all the models (such as downloading or saving, etc.) for all models (such as downloading or saving).
Parameters: Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding. Whether to flip the sin to cos in the time embedding.
...@@ -171,9 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -171,9 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use. The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip Block type for the middle of the UNet; it can be either `UNetMidBlockFlatCrossAttn` or
the mid block layer if `None`. `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
The tuple of upsample blocks to use. The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see Whether to include self-attention in the basic transformer blocks, see
...@@ -185,52 +185,52 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -185,52 +185,52 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features. The dimension of the cross attention features.
encoder_hid_dim (`int`, *optional*, defaults to None): encoder_hid_dim (`int`, *optional*, defaults to `None`):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`. dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to None): encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*): num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim` The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to None): addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer. "text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`): time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, default to `None`): time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding. An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, default to `None`): time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use on the time embeddings only one time before they as passed to the rest Optional activation function to use only once on the time embeddings before they are passed to the rest of
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str, *optional*, default to `None`): timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`): time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in timestep embedding. The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer. conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer. conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings. embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
default to `False`. otherwise.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
...@@ -656,11 +656,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -656,11 +656,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention.
Parameters: Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers. for **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
""" """
count = len(self.attn_processors.keys()) count = len(self.attn_processors.keys())
...@@ -694,15 +698,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -694,15 +698,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention When this option is enabled, the attention module splits the input tensor in slices to compute attention in
in several steps. This is useful to save some memory in exchange for a small speed decrease. several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args: Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `num_attention_heads // slice_size`. In this case, provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
`num_attention_heads` must be a multiple of `slice_size`. must be a multiple of `slice_size`.
""" """
sliceable_head_dims = [] sliceable_head_dims = []
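The `slice_size` options above map onto the user-facing API roughly as sketched below. The setter's name is not visible in this hunk; in diffusers UNet models it is exposed as `set_attention_slice`, and pipelines wrap it as `enable_attention_slicing`. The checkpoint id is only an example.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

pipe.enable_attention_slicing("auto")  # halve the attention head input: attention runs in two steps
pipe.enable_attention_slicing("max")   # one slice at a time: maximum memory savings, slowest
pipe.enable_attention_slicing(2)       # fixed slice size; the head dim must be divisible by it
pipe.disable_attention_slicing()       # back to full attention
```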
...@@ -775,29 +779,28 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -775,29 +779,28 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]: ) -> Union[UNet2DConditionOutput, Tuple]:
r""" r"""
The [`UNetFlatConditionModel`] forward method.
Args: Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor sample (`torch.FloatTensor`):
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps The noisy input tensor with the following shape `(batch, channel, height, width)`.
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
encoder_attention_mask (`torch.Tensor`): encoder_attention_mask (`torch.Tensor`):
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. Positions
discard. Mask will be converted into a bias, which adds large negative values to attention scores that are `True` are kept, positions that are `False` are discarded. The mask is converted into a bias,
corresponding to "discard" tokens. which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
added_cond_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified includes additional conditions that can be used for additional time
embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
`addition_embed_type` for more information.
Returns: Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
returning a tuple, the first element is the sample tensor. a `tuple` is returned where the first element is the sample tensor.
""" """
# By default samples have to be at least a multiple of the overall upsampling factor. # By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
......
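To show how the `encoder_attention_mask` described above is typically constructed, here is a hedged sketch (not part of this diff). The tokenizer, text encoder, `unet`, `latents` and `t` are assumed to come from a Stable Diffusion-style setup and are not defined here.

```python
import torch

text_inputs = tokenizer(
    ["a photo of a cat"], padding="max_length", truncation=True, return_tensors="pt"
)
text_embeddings = text_encoder(text_inputs.input_ids)[0]  # (batch, sequence_length, feature_dim)

# Boolean mask per token: True = keep, False = discard (e.g. padding tokens).
encoder_attention_mask = text_inputs.attention_mask.bool()  # (batch, sequence_length)

noise_pred = unet(
    latents,                                   # noisy sample, (batch, channel, height, width)
    t,                                         # timestep
    encoder_hidden_states=text_embeddings,
    encoder_attention_mask=encoder_attention_mask,
    return_dict=False,
)[0]
```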