Unverified Commit 5e6417e9 authored by Kashif Rasul, committed by GitHub

[Docs] Models (#416)



* docs for attention

* types for embeddings

* unet2d docstrings

* UNet2DConditionModel docstrings

* fix typos

* style and vq-vae docstrings

* docstrings  for VAE

* Update src/diffusers/models/unet_2d.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* added inherits from sentence

* docstring to forward

* make style

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* finish model docs

* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 234e90cc
@@ -16,13 +16,32 @@ Diffusers contains pretrained models for popular algorithms and modules for crea
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.

-## API
-
-Models should provide the `def forward` function and initialization of the model.
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
-
-## Examples
-
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-- TODO: mention VAE / SDE score estimation
\ No newline at end of file
+## ModelMixin
+[[autodoc]] ModelMixin
+
+## UNet2DOutput
+[[autodoc]] models.unet_2d.UNet2DOutput
+
+## UNet2DModel
+[[autodoc]] UNet2DModel
+
+## UNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
+
+## UNet2DConditionModel
+[[autodoc]] UNet2DConditionModel
+
+## DecoderOutput
+[[autodoc]] models.vae.DecoderOutput
+
+## VQEncoderOutput
+[[autodoc]] models.vae.VQEncoderOutput
+
+## VQModel
+[[autodoc]] VQModel
+
+## AutoencoderKLOutput
+[[autodoc]] models.vae.AutoencoderKLOutput
+
+## AutoencoderKL
+[[autodoc]] AutoencoderKL
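All of the classes documented above inherit from `ModelMixin`, so they share one loading/saving workflow. A minimal sketch of that workflow (the checkpoint id `google/ddpm-cat-256` is an assumption; any `UNet2DModel` repo on the Hub works the same way):

```python
from diffusers import UNet2DModel

# Download weights + config from the Hub (or reuse the local cache); assumed checkpoint id.
model = UNet2DModel.from_pretrained("google/ddpm-cat-256")

# Persist weights and config together so the model can be reloaded from disk.
model.save_pretrained("./my_model_directory")
reloaded = UNet2DModel.from_pretrained("./my_model_directory")
```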
@@ -117,27 +117,12 @@ class ModelMixin(torch.nn.Module):
     Base class for all models.

     [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
-    and saving models as well as a few methods common to all models to:
-
-        - resize the input embeddings,
-        - prune heads in the self-attention heads.
-
-    Class attributes (overridden by derived classes):
-
-        - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
-          model architecture.
-        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
-          taking as arguments:
-            - **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
-            - **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
-            - **path** (`str`) -- A path to the TensorFlow checkpoint.
-        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
-          classes of the same architecture adding modules on top of the base model.
-        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
-        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
-          models, `pixel_values` for vision models and `input_values` for speech models).
+    and saving models.
+
+    Class attributes:
+
+        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+          [`~modeling_utils.ModelMixin.save_pretrained`].
     """

     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -150,11 +135,10 @@ class ModelMixin(torch.nn.Module):
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         save_function: Callable = torch.save,
-        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~ModelMixin.from_pretrained`]` class method.
+        `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -166,9 +150,6 @@ class ModelMixin(torch.nn.Module):
             save_function (`Callable`):
                 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                 need to replace `torch.save` by another method.
-            kwargs:
-                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -224,34 +205,12 @@ class ModelMixin(torch.nn.Module):
                     - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
                       e.g., `./my_model_directory/`.
-            config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
-                Can be either:
-                    - an instance of a class derived from [`ConfigMixin`],
-                    - a string or path valid as input to [`~ConfigMixin.from_pretrained`].
-                ConfigMixinuration for the model to use instead of an automatically loaded configuration.
-                ConfigMixinuration can be automatically loaded when:
-                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
-                      model).
-                    - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
-                      directory.
-                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
-                      configuration JSON file named *config.json* is found in the directory.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
-            from_tf (`bool`, *optional*, defaults to `False`):
-                Load the model weights from a TensorFlow checkpoint save file (see docstring of
-                `pretrained_model_name_or_path` argument).
-            from_flax (`bool`, *optional*, defaults to `False`):
-                Load the model weights from a Flax checkpoint save file (see docstring of
-                `pretrained_model_name_or_path` argument).
-            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
-                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
-                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
-                checkpoint with 3 labels).
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -267,7 +226,7 @@ class ModelMixin(torch.nn.Module):
                 Whether or not to only look at local files (i.e., do not try to download the model).
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `transformers-cli login` (stored in `~/.huggingface`).
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -278,18 +237,7 @@ class ModelMixin(torch.nn.Module):
                 Please refer to the mirror site for more information.
             kwargs (remaining dictionary of keyword arguments, *optional*):
-                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
-                automatically loaded:
-                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
-                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
-                      already been done)
-                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
-                      initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
-                      to a configuration attribute will be used to override said attribute with the supplied `kwargs`
-                      value. Remaining keys that do not correspond to any configuration attribute will be passed to the
-                      underlying model's `__init__` function.
+                Can be used to update the [`ConfigMixin`] of the model (after it has been loaded).

         <Tip>
@@ -299,8 +247,8 @@ class ModelMixin(torch.nn.Module):
         <Tip>

-        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
-        use this method in a firewalled environment.
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.

         </Tip>
@@ -404,7 +352,7 @@ class ModelMixin(torch.nn.Module):
                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                     f" directory containing a file named {WEIGHTS_NAME} or"
                     " \nCheckout your internet connection or see how to run the library in"
-                    " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                 )
             except EnvironmentError:
                 raise EnvironmentError(
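Related to the offline-mode tip above, the documented `local_files_only` flag resolves everything from the local cache, which is the programmatic way to stay inside a firewalled environment (assuming the checkpoint was cached beforehand):

```python
from diffusers import UNet2DModel

# No network requests are made; raises if the files are not already cached locally.
model = UNet2DModel.from_pretrained("google/ddpm-cat-256", local_files_only=True)
```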
...
 import math
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention
+    Uses three q, k, v linear layers to compute attention.
+
+    Parameters:
+        channels (:obj:`int`): The number of channels in the input and output.
+        num_head_channels (:obj:`int`, *optional*):
+            The number of channels in each head. If None, then `num_heads` = 1.
+        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """

     def __init__(
         self,
-        channels,
-        num_head_channels=None,
-        num_groups=32,
-        rescale_output_factor=1.0,
-        eps=1e-5,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
     ):
         super().__init__()
         self.channels = channels
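A shape sketch for the newly documented `AttentionBlock` parameters (illustrative values only):

```python
import torch
from diffusers.models.attention import AttentionBlock

# channels must be divisible by both num_groups and num_head_channels.
attn = AttentionBlock(channels=64, num_head_channels=32, num_groups=32)
hidden_states = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
out = attn(hidden_states)                   # residual self-attention, same shape as the input
```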
@@ -86,10 +95,26 @@ class AttentionBlock(nn.Module):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
+    standard transformer action. Finally, reshape to image.
+
+    Parameters:
+        in_channels (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
     """

-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+    def __init__(
+        self,
+        in_channels: int,
+        n_heads: int,
+        d_head: int,
+        depth: int = 1,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
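And the corresponding sketch for `SpatialTransformer`, showing where `context_dim` comes in (shapes and values are assumptions):

```python
import torch
from diffusers.models.attention import SpatialTransformer

transformer = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, depth=1, context_dim=64)
hidden_states = torch.randn(2, 32, 16, 16)  # image-like input: (batch, channels, height, width)
context = torch.randn(2, 77, 64)            # conditioning sequence: (batch, tokens, context_dim)
out = transformer(hidden_states, context)   # projected to (b, t, d), attended, reshaped back to (2, 32, 16, 16)
```

Internally, each of the `depth` layers is one of the `BasicTransformerBlock`s documented next.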
@@ -127,7 +152,29 @@ class SpatialTransformer(nn.Module):
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        d_head: int,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+        gated_ff: bool = True,
+        checkpoint: bool = True,
+    ):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -154,7 +201,21 @@ class BasicTransformerBlock(nn.Module):
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (:obj:`int`): The number of channels in the query.
+        context_dim (:obj:`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
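`CrossAttention` is the layer the blocks above build on; a shape sketch of the query/context split (leaving `context` out degenerates to self-attention because `context_dim` defaults to `query_dim`):

```python
import torch
from diffusers.models.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 256, 320)       # (batch, query tokens, query_dim)
context = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
out = attn(x, context)             # (2, 256, 320): one attended vector per query token
```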
@@ -228,7 +289,20 @@ class CrossAttention(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
@@ -242,7 +316,15 @@ class FeedForward(nn.Module):
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
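For reference, the gating that `GEGLU` adds over a plain linear layer is the value/gate split from the paper linked above; `FeedForward` switches to it when `glu=True`. A standalone sketch of the same computation (not the class itself):

```python
import torch
import torch.nn.functional as F

def geglu_reference(x: torch.Tensor, proj: torch.nn.Linear) -> torch.Tensor:
    # proj maps dim_in -> 2 * dim_out; one half is the value, the other drives the GELU gate.
    value, gate = proj(x).chunk(2, dim=-1)
    return value * F.gelu(gate)

proj = torch.nn.Linear(64, 2 * 128)
out = geglu_reference(torch.randn(2, 10, 64), proj)  # (2, 10, 128)
```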
...
@@ -19,7 +19,12 @@ from torch import nn
 def get_timestep_embedding(
-    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
 ):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -55,7 +60,7 @@ def get_timestep_embedding(
 class TimestepEmbedding(nn.Module):
-    def __init__(self, channel, time_embed_dim, act_fn="silu"):
+    def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
         super().__init__()

         self.linear_1 = nn.Linear(channel, time_embed_dim)
@@ -75,7 +80,7 @@ class TimestepEmbedding(nn.Module):
 class Timesteps(nn.Module):
-    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
@@ -94,7 +99,7 @@ class Timesteps(nn.Module):
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""

-    def __init__(self, embedding_size=256, scale=1.0):
+    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
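A quick sketch of the now-typed embedding helpers (values are assumptions):

```python
import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

timesteps = torch.tensor([0, 10, 999])

# Functional form: (3, 320) sinusoidal embeddings.
emb = get_timestep_embedding(timesteps, embedding_dim=320)

# Module form wrapping the same function, as the UNets use it.
time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = time_proj(timesteps)  # (3, 320)
```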
...
@@ -23,6 +23,38 @@ class UNet2DOutput(BaseOutput):
 class UNet2DModel(ModelMixin, ConfigMixin):
+    r"""
+    UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
+            Input sample size.
+        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip sin to cos for fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+            Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
+        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -136,6 +168,17 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         timestep: Union[torch.Tensor, float, int],
         return_dict: bool = True,
     ) -> Union[UNet2DOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
+            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
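An end-to-end shape sketch for the documented `forward` (the tiny configuration below is an assumption chosen so the example runs quickly, not the defaults listed above):

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

sample = torch.randn(1, 3, 32, 32)  # noisy input
timestep = torch.tensor([10])       # one timestep per batch element
out = model(sample, timestep)       # UNet2DOutput when return_dict=True
noise_pred = out.sample             # (1, 3, 32, 32)
```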
...
@@ -23,6 +23,37 @@ class UNet2DConditionOutput(BaseOutput):
 class UNet2DConditionModel(ModelMixin, ConfigMixin):
+    r"""
+    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
+    timestep and returns sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        sample_size (`int`, *optional*): The size of the input sample.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+            The tuple of upsample blocks to use.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -162,6 +193,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         encoder_hidden_states: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
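The conditional variant differs only by the `encoder_hidden_states` argument; again a tiny assumed configuration rather than the Stable-Diffusion-sized defaults:

```python
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
)

sample = torch.randn(1, 4, 16, 16)              # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 64)  # conditioning sequence, e.g. text embeddings
out = model(sample, timestep, encoder_hidden_states)
noise_pred = out.sample                         # (1, 4, 16, 16)
```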
...
@@ -371,6 +371,27 @@ class DiagonalGaussianDistribution(object):
 class VQModel(ModelMixin, ConfigMixin):
+    r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and
+    Koray Kavukcuoglu.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): TODO
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -440,6 +461,12 @@ class VQModel(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=dec)

     def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
         x = sample
         h = self.encode(x).latents
         dec = self.decode(h).sample
@@ -451,6 +478,26 @@ class VQModel(ModelMixin, ConfigMixin):
 class AutoencoderKL(ModelMixin, ConfigMixin):
+    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P.
+    Kingma and Max Welling.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): TODO
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -512,6 +559,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     def forward(
         self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
     ) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
         x = sample
         posterior = self.encode(x).latent_dist
         if sample_posterior:
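A round-trip sketch for the two autoencoders documented above, using tiny untrained configurations (all values are assumptions):

```python
import torch
from diffusers import AutoencoderKL, VQModel

image = torch.randn(1, 3, 32, 32)

# KL autoencoder: encode to a posterior, sample from it (or take its mode), then decode.
vae = AutoencoderKL(block_out_channels=(32,))
posterior = vae.encode(image).latent_dist
latents = posterior.sample()
reconstruction = vae.decode(latents).sample  # (1, 3, 32, 32)

# VQ-VAE: forward() runs encode -> quantize -> decode in one call.
vq = VQModel(block_out_channels=(32,), num_vq_embeddings=128)
reconstruction = vq(image).sample            # (1, 3, 32, 32)
```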
...