Unverified Commit 174dcd69 authored by Steven Liu, committed by GitHub

[docs] Model API (#3562)

* add modelmixin and unets

* remove old model page

* minor fixes

* fix unet2dcondition

* add vqmodel and autoencoderkl

* add rest of models

* fix autoencoderkl path

* fix toctree

* fix toctree again

* apply feedback

* apply feedback

* fix copies

* fix controlnet copy

* fix copies
parent cdf2ae8a
...@@ -132,8 +132,6 @@
    title: Conceptual Guides
- sections:
  - sections:
-   - local: api/models
-     title: Models
    - local: api/attnprocessor
      title: Attention Processor
    - local: api/diffusion_pipeline
...@@ -151,6 +149,30 @@
    - local: api/image_processor
      title: VAE Image Processor
    title: Main Classes
+ - sections:
+   - local: api/models/overview
+     title: Overview
+   - local: api/models/unet
+     title: UNet1DModel
+   - local: api/models/unet2d
+     title: UNet2DModel
+   - local: api/models/unet2d-cond
+     title: UNet2DConditionModel
+   - local: api/models/unet3d-cond
+     title: UNet3DConditionModel
+   - local: api/models/vq
+     title: VQModel
+   - local: api/models/autoencoderkl
+     title: AutoencoderKL
+   - local: api/models/transformer2d
+     title: Transformer2D
+   - local: api/models/transformer_temporal
+     title: Transformer Temporal
+   - local: api/models/prior_transformer
+     title: Prior Transformer
+   - local: api/models/controlnet
+     title: ControlNet
+   title: Models
  - sections:
    - local: api/pipelines/overview
      title: Overview
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Models
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample by modeling the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\).
The models are built on the base class [`ModelMixin`], which is a [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) with basic functionality for saving and loading models both locally and from the Hugging Face Hub.
## ModelMixin
[[autodoc]] ModelMixin
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
## UNet2DModel
[[autodoc]] UNet2DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
## UNet3DConditionModel
[[autodoc]] UNet3DConditionModel
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
## VQEncoderOutput
[[autodoc]] models.vq_model.VQEncoderOutput
## VQModel
[[autodoc]] VQModel
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
## AutoencoderKL
[[autodoc]] AutoencoderKL
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
## TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
## TransformerTemporalModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
## PriorTransformer
[[autodoc]] models.prior_transformer.PriorTransformer
## PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
## ControlNetOutput
[[autodoc]] models.controlnet.ControlNetOutput
## ControlNetModel
[[autodoc]] ControlNetModel
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
# AutoencoderKL
The variational autoencoder (VAE) model with KL loss was introduced in [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114v11) by Diederik P. Kingma and Max Welling. The model is used in 🤗 Diffusers to encode images into latents and to decode latent representations into images.
The abstract from the paper is:
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
## AutoencoderKL
[[autodoc]] AutoencoderKL
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
\ No newline at end of file
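As a quick orientation, here is a minimal sketch of round-tripping a tensor through the VAE. It assumes the Stable Diffusion v1-5 repository layout (a `vae` subfolder) and the `scaling_factor` convention described above; the image is random data purely to show the shapes.

```py
import torch
from diffusers import AutoencoderKL

# Load the VAE weights shipped with Stable Diffusion v1-5 (assumed repository layout).
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# A fake image batch in [-1, 1]: (batch, channels, height, width).
image = torch.randn(1, 3, 512, 512)

with torch.no_grad():
    # Encode to a diagonal Gaussian posterior and sample a latent from it.
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

    # Decode the latents back to image space.
    reconstruction = vae.decode(latents / vae.config.scaling_factor).sample

print(latents.shape, reconstruction.shape)  # (1, 4, 64, 64) and (1, 3, 512, 512)
```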
# ControlNet
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
## ControlNetModel
[[autodoc]] ControlNetModel
## ControlNetOutput
[[autodoc]] models.controlnet.ControlNetOutput
## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
\ No newline at end of file
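The sketch below illustrates how the down-block and mid-block residuals returned by a ControlNet are fed back into the UNet. It assumes the Stable Diffusion v1-5 UNet and the community `lllyasviel/sd-controlnet-canny` weights; the tensors are random placeholders.

```py
import torch
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

latents = torch.randn(1, 4, 64, 64)        # noisy latents
timestep = torch.tensor(10)                # current diffusion timestep
text_embeds = torch.randn(1, 77, 768)      # CLIP text-encoder hidden states
cond_image = torch.randn(1, 3, 512, 512)   # conditioning image (e.g. canny edges) in pixel space

with torch.no_grad():
    down_res, mid_res = controlnet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        controlnet_cond=cond_image,
        return_dict=False,
    )
    # The residuals are added to the UNet's own down- and mid-block activations.
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample

print(noise_pred.shape)  # (1, 4, 64, 64)
```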
# Models
🤗 Diffusers provides pretrained models for popular algorithms and modules to create custom diffusion systems. The primary function of models is to denoise an input sample as modeled by the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\).
All models are built from the base [`ModelMixin`] class, which is a [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) providing basic functionality for saving and loading models, locally and from the Hugging Face Hub.
## ModelMixin
[[autodoc]] ModelMixin
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
\ No newline at end of file
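For example, any model built on [`ModelMixin`] is loaded and saved with the same two calls. The sketch below assumes the Stable Diffusion v1-5 repository layout (a `unet` subfolder) and a hypothetical local directory name.

```py
from diffusers import UNet2DConditionModel

# Download (or read from the local cache) a pretrained model hosted on the Hub.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Save the weights and config locally; the directory can be passed back to from_pretrained.
unet.save_pretrained("./my-unet")
reloaded = UNet2DConditionModel.from_pretrained("./my-unet")
```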
# Prior Transformer
The Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process.
The abstract from the paper is:
*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*
## PriorTransformer
[[autodoc]] PriorTransformer
## PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
\ No newline at end of file
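The sketch below only illustrates the tensor interface, using a deliberately tiny, randomly initialized prior; the sizes are made up, and the keyword names follow the [`PriorTransformer`] forward signature as understood here, so treat it as an approximation rather than a reference.

```py
import torch
from diffusers import PriorTransformer

# A tiny, randomly initialized prior purely to show the tensor interface.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=16,
    num_layers=2,
    embedding_dim=64,   # size of the CLIP embedding being modeled (made up here)
    num_embeddings=8,   # number of text-encoder hidden states (made up here)
)

noisy_image_embeds = torch.randn(2, 64)        # the sample being denoised
text_embeds = torch.randn(2, 64)               # pooled CLIP text embedding
text_encoder_states = torch.randn(2, 8, 64)    # per-token text-encoder hidden states

with torch.no_grad():
    out = prior(
        noisy_image_embeds,
        timestep=3,
        proj_embedding=text_embeds,
        encoder_hidden_states=text_encoder_states,
    )

print(out.predicted_image_embedding.shape)  # (2, 64)
```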
# Transformer2D
A Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
When the input is **continuous**:
1. Project the input and reshape it to `(batch_size, sequence_length, feature_dimension)`.
2. Apply the Transformer blocks in the standard way.
3. Reshape to image.
When the input is **discrete**:
<Tip>
It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.
</Tip>
1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings.
2. Apply the Transformer blocks in the standard way.
3. Predict classes of unnoised image.
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
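A minimal sketch of the continuous-input path described above, using a tiny, randomly initialized model (the sizes are made up for illustration):

```py
import torch
from diffusers import Transformer2DModel

# A tiny transformer in the *continuous* input mode: the input is projected,
# flattened into a sequence, run through the blocks, and reshaped back to an
# image of the same size.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=32,   # must be divisible by norm_num_groups (defaults to 32)
    num_layers=1,
)

hidden_states = torch.randn(1, 32, 16, 16)  # (batch, channels, height, width)
with torch.no_grad():
    sample = model(hidden_states).sample

print(sample.shape)  # (1, 32, 16, 16)
```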
# Transformer Temporal
A Transformer model for video-like data.
## TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
## TransformerTemporalModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
\ No newline at end of file
# UNet1DModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 1D UNet model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
\ No newline at end of file
# UNet2DConditionModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
\ No newline at end of file
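A minimal sketch of a single denoising step, assuming the Stable Diffusion v1-5 repository layout (a `unet` subfolder); the latents and text embeddings are random placeholders:

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

latents = torch.randn(1, 4, 64, 64)      # noisy latents: (batch, channels, height, width)
timestep = torch.tensor(50)              # current diffusion timestep
text_embeds = torch.randn(1, 77, 768)    # CLIP text-encoder hidden states

with torch.no_grad():
    noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample

print(noise_pred.shape)  # (1, 4, 64, 64), the same size as the input
```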
# UNet2DModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet2DModel
[[autodoc]] UNet2DModel
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
\ No newline at end of file
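A minimal sketch of a single denoising step with an unconditional UNet; the `google/ddpm-cat-256` checkpoint is assumed to be available on the Hub:

```py
import torch
from diffusers import UNet2DModel

# An unconditional DDPM UNet.
unet = UNet2DModel.from_pretrained("google/ddpm-cat-256")

noisy_sample = torch.randn(
    1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size
)

with torch.no_grad():
    noise_pred = unet(noisy_sample, timestep=10).sample

print(noise_pred.shape)  # same size as the input, e.g. (1, 3, 256, 256)
```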
# UNet3DConditionModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 3D UNet conditional model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet3DConditionModel
[[autodoc]] UNet3DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
\ No newline at end of file
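A minimal sketch of a single denoising step on a short latent video, assuming the ModelScope text-to-video repository layout (`damo-vilab/text-to-video-ms-1.7b` with a `unet` subfolder); shapes are read from the model config where possible:

```py
import torch
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained("damo-vilab/text-to-video-ms-1.7b", subfolder="unet")

batch, num_frames = 1, 8
latents = torch.randn(batch, unet.config.in_channels, num_frames, 32, 32)  # (B, C, F, H, W)
text_embeds = torch.randn(batch, 77, unet.config.cross_attention_dim)

with torch.no_grad():
    noise_pred = unet(latents, timestep=10, encoder_hidden_states=text_embeds).sample

print(noise_pred.shape)  # (1, 4, 8, 32, 32)
```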
# VQModel
The VQ-VAE model was introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. The model is used in 🤗 Diffusers to decode latent representations into images. Unlike [`AutoencoderKL`], the [`VQModel`] works in a quantized latent space.
The abstract from the paper is:
*Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEncoder (VQ-VAE), differs from VAEs in two key ways: the encoder network outputs discrete, rather than continuous, codes; and the prior is learnt rather than static. In order to learn a discrete latent representation, we incorporate ideas from vector quantisation (VQ). Using the VQ method allows the model to circumvent issues of "posterior collapse" -- where the latents are ignored when they are paired with a powerful autoregressive decoder -- typically observed in the VAE framework. Pairing these representations with an autoregressive prior, the model can generate high quality images, videos, and speech as well as doing high quality speaker conversion and unsupervised learning of phonemes, providing further evidence of the utility of the learnt representations.*
## VQModel
[[autodoc]] VQModel
## VQEncoderOutput
[[autodoc]] models.vq_model.VQEncoderOutput
\ No newline at end of file
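A minimal sketch of encoding and decoding with a [`VQModel`], assuming the `CompVis/ldm-celebahq-256` repository layout (a `vqvae` subfolder); the image is random data purely to show the shapes:

```py
import torch
from diffusers import VQModel

# The VQ-VAE weights from the unconditional LDM CelebA-HQ repository are assumed here.
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")

image = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    latents = vqvae.encode(image).latents          # continuous encoder output
    reconstruction = vqvae.decode(latents).sample  # quantized against the codebook, then decoded

print(latents.shape, reconstruction.shape)
```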
...@@ -39,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput): ...@@ -39,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput):
class AutoencoderKL(ModelMixin, ConfigMixin): class AutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma r"""
and Max Welling. A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
implements for all the model (such as downloading or saving, etc.) for all models (such as downloading or saving).
Parameters: Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image. in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output. out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to : down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to : up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to : block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
obj:`(64,)`): Tuple of block output channels. Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215): scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion training set. This is used to scale the latent space to have unit variance when training the diffusion
...@@ -131,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -131,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def enable_tiling(self, use_tiling: bool = True): def enable_tiling(self, use_tiling: bool = True):
r""" r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
the processing of larger images. processing larger images.
""" """
self.use_tiling = use_tiling self.use_tiling = use_tiling
def disable_tiling(self): def disable_tiling(self):
r""" r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
computing decoding in one step. decoding in one step.
""" """
self.enable_tiling(False) self.enable_tiling(False)
...@@ -152,7 +152,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -152,7 +152,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def disable_slicing(self): def disable_slicing(self):
r""" r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step. decoding in one step.
""" """
self.use_slicing = False self.use_slicing = False
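A usage sketch for the tiling and slicing toggles documented in this hunk (assuming the Stable Diffusion v1-5 VAE as in the docs pages above):

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# Trade a little speed for a large memory saving on big inputs.
vae.enable_tiling()    # encode/decode in overlapping tiles
vae.enable_slicing()   # process batch elements one at a time

with torch.no_grad():
    latents = vae.encode(torch.randn(2, 3, 1024, 1024)).latent_dist.sample()

# Revert to single-pass encoding/decoding.
vae.disable_tiling()
vae.disable_slicing()
```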
...@@ -185,11 +185,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -185,11 +185,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention.
Parameters: Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers. for **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
""" """
count = len(self.attn_processors.keys()) count = len(self.attn_processors.keys())
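A usage sketch for `set_attn_processor`, assuming the default processor class from `diffusers.models.attention_processor`:

```py
from diffusers import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# Set a single processor instance for every Attention layer...
vae.set_attn_processor(AttnProcessor())

# ...or pass a dict keyed by layer path to target each layer individually.
vae.set_attn_processor({name: AttnProcessor() for name in vae.attn_processors})
```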
...@@ -274,14 +278,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -274,14 +278,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder. r"""Encode a batch of images using a tiled encoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable. output, but they should be much less noticeable.
x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
""" """
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
...@@ -319,17 +330,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -319,17 +330,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior) return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""Decode a batch of images using a tiled decoder. r"""
Decode a batch of images using a tiled decoder.
Args: Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several z (`torch.FloatTensor`): Input batch of latent vectors.
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: return_dict (`bool`, *optional*, defaults to `True`):
different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable. Returns:
z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to [`~models.vae.DecoderOutput`] or `tuple`:
`True`): If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
Whether or not to return a [`DecoderOutput`] instead of a plain tuple. returned.
""" """
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
......
...@@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass @dataclass
class ControlNetOutput(BaseOutput): class ControlNetOutput(BaseOutput):
"""
The output of [`ControlNetModel`].
Args:
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_block_res_sample (`torch.Tensor`):
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
down_block_res_samples: Tuple[torch.Tensor] down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor mid_block_res_sample: torch.Tensor
...@@ -87,6 +101,58 @@ class ControlNetConditioningEmbedding(nn.Module): ...@@ -87,6 +101,58 @@ class ControlNetConditioningEmbedding(nn.Module):
class ControlNetModel(ModelMixin, ConfigMixin): class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for normalization. If `None`, normalization and activation layers are skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@register_to_config @register_to_config
...@@ -283,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -283,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
load_weights_from_unet: bool = True, load_weights_from_unet: bool = True,
): ):
r""" r"""
Instantiate Controlnet class from UNet2DConditionModel. Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters: Parameters:
unet (`UNet2DConditionModel`): unet (`UNet2DConditionModel`):
UNet model which weights are copied to the ControlNet. Note that all configuration options are also The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
copied where applicable. where applicable.
""" """
controlnet = cls( controlnet = cls(
in_channels=unet.config.in_channels, in_channels=unet.config.in_channels,
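A usage sketch for `from_unet` as documented in this hunk (the Stable Diffusion v1-5 UNet layout is assumed):

```py
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Initialize a ControlNet whose blocks copy the UNet's weights and configuration.
controlnet = ControlNetModel.from_unet(unet)
```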
...@@ -357,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -357,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention.
Parameters: Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers. for **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
""" """
count = len(self.attn_processors.keys()) count = len(self.attn_processors.keys())
...@@ -397,15 +467,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -397,15 +467,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention When this option is enabled, the attention module splits the input tensor in slices to compute attention in
in several steps. This is useful to save some memory in exchange for a small speed decrease. several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args: Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `num_attention_heads // slice_size`. In this case, provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
`num_attention_heads` must be a multiple of `slice_size`. must be a multiple of `slice_size`.
""" """
sliceable_head_dims = [] sliceable_head_dims = []
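A usage sketch for the slicing policies described in this hunk (the `lllyasviel/sd-controlnet-canny` checkpoint is assumed, as elsewhere in these docs):

```py
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

# Compute attention for each head in two steps (the "auto" policy)...
controlnet.set_attention_slice("auto")

# ...or save as much memory as possible by running one slice at a time.
controlnet.set_attention_slice("max")
```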
...@@ -476,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -476,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin):
guess_mode: bool = False, guess_mode: bool = False,
return_dict: bool = True, return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]: ) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
you remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order # check channel order
channel_order = self.config.controlnet_conditioning_channel_order channel_order = self.config.controlnet_conditioning_channel_order
......
...@@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import ( ...@@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import (
@flax.struct.dataclass @flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput): class FlaxControlNetOutput(BaseOutput):
"""
The output of [`FlaxControlNetModel`].
Args:
down_block_res_samples (`jnp.ndarray`):
mid_block_res_sample (`jnp.ndarray`):
"""
down_block_res_samples: jnp.ndarray down_block_res_samples: jnp.ndarray
mid_block_res_sample: jnp.ndarray mid_block_res_sample: jnp.ndarray
...@@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module): ...@@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
@flax_register_to_config @flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
r""" r"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN A ControlNet model.
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides implemented for all models (such as downloading or saving).
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..." This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior. general usage and behavior.
Finally, this model supports inherent JAX features such as: Inherent JAX features such as the following are supported:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
...@@ -120,9 +124,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -120,9 +124,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
The size of the input sample. The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample. The number of channels in the input sample.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", The tuple of downsample blocks to use.
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block. The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): layers_per_block (`int`, *optional*, defaults to 2):
...@@ -139,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -139,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
Whether to flip the sin to cos in the time embedding. Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
The channel order of conditional image. Will convert it to `rgb` if it's `bgr` The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in conditioning_embedding layer The tuple of output channel for each block in the `conditioning_embedding` layer.
""" """
sample_size: int = 32 sample_size: int = 32
in_channels: int = 4 in_channels: int = 4
......
...@@ -44,10 +44,12 @@ logger = logging.get_logger(__name__) ...@@ -44,10 +44,12 @@ logger = logging.get_logger(__name__)
class FlaxModelMixin: class FlaxModelMixin:
r""" r"""
Base class for all flax models. Base class for all Flax models.
[`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
downloading and saving models. saving models.
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
""" """
config_name = CONFIG_NAME config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
...@@ -89,15 +91,15 @@ class FlaxModelMixin: ...@@ -89,15 +91,15 @@ class FlaxModelMixin:
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place. the `params` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments: Arguments:
params (`Union[Dict, FrozenDict]`): params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters. A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`): mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
you want to cast, and should be `False` for those you want to skip. for params you want to cast, and `False` for those you want to skip.
Examples: Examples:
...@@ -132,8 +134,8 @@ class FlaxModelMixin: ...@@ -132,8 +134,8 @@ class FlaxModelMixin:
params (`Union[Dict, FrozenDict]`): params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters. A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`): mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
you want to cast, and should be `False` for those you want to skip for params you want to cast, and `False` for those you want to skip.
Examples: Examples:
...@@ -155,15 +157,15 @@ class FlaxModelMixin: ...@@ -155,15 +157,15 @@ class FlaxModelMixin:
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
`params` in place. `params` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed. half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments: Arguments:
params (`Union[Dict, FrozenDict]`): params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters. A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`): mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
you want to cast, and should be `False` for those you want to skip for params you want to cast, and `False` for those you want to skip.
Examples: Examples:
...@@ -201,71 +203,68 @@ class FlaxModelMixin: ...@@ -201,71 +203,68 @@ class FlaxModelMixin:
**kwargs, **kwargs,
): ):
r""" r"""
Instantiate a pretrained flax model from a pre-trained model configuration. Instantiate a pretrained Flax model from a pretrained model configuration.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either: Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
Valid model ids are namespaced under a user or organization name, like hosted on the Hub.
`runwayml/stable-diffusion-v1-5`. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], using [`~FlaxModelMixin.save_pretrained`].
e.g., `./my_model_directory/`.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs). `jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`. specified, all the computation will be performed with the given `dtype`.
<Tip>
This only specifies the dtype of the *computation* and does not influence the dtype of model
parameters.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
parameters.** [`~FlaxModelMixin.to_bf16`].
</Tip>
If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
[`~ModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments are passed to the underlying model's `__init__` method.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch checkpoint save file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it is loaded) and initiate the model (for
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
model's `__init__` method (we assume all relevant updates to the configuration have already been
done).
- If a configuration is not provided, `kwargs` are first passed to the configuration class
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
model's `__init__` function.
Examples:
...@@ -276,7 +275,16 @@ class FlaxModelMixin:
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
...@@ -491,18 +499,18 @@ class FlaxModelMixin:
is_main_process: bool = True,
):
"""
Save a model and its configuration file to a directory so that it can be reloaded using the
[`~FlaxModelMixin.from_pretrained`] class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training when you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
...
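The documented `from_pretrained`/`save_pretrained` pair round-trips cleanly. A minimal sketch follows; the local directory name and the `subfolder` argument are illustrative assumptions rather than values taken from the docstring above, and whether a given checkpoint ships Flax weights (or needs `from_pt=True`) depends on the repository:

```py
from diffusers import FlaxUNet2DConditionModel

# Flax modules are stateless, so the parameters are returned separately from the model.
# If the repository only ships PyTorch weights, pass `from_pt=True` to convert them.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Save the model and its parameters, then reload them from the local directory.
model.save_pretrained("./sd15-unet-flax", params=params)
model, params = FlaxUNet2DConditionModel.from_pretrained("./sd15-unet-flax")
```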
...@@ -16,6 +16,8 @@ from .modeling_utils import ModelMixin
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
The output of [`PriorTransformer`].
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
...@@ -26,27 +28,20 @@ class PriorTransformerOutput(BaseOutput):
class PriorTransformer(ModelMixin, ConfigMixin):
"""
A Prior Transformer model.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768):
The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77):
The max number of CLIP embeddings allowed (the length of the prompt after it has been tokenized).
additional_embeddings (`int`, *optional*, defaults to 4):
The number of additional tokens appended to the projected `hidden_states`. The actual length of the used
`hidden_states` is `num_embeddings + additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
@register_to_config
...@@ -133,11 +128,15 @@ class PriorTransformer(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
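As a quick illustration of the two accepted forms, here is a sketch that assumes the default `AttnProcessor` class from `diffusers.models.attention_processor` and a small, untrained model instantiated purely for demonstration (the sizes are not from a real checkpoint):

```py
from diffusers import PriorTransformer
from diffusers.models.attention_processor import AttnProcessor

# Tiny, untrained instance just to demonstrate the call.
prior = PriorTransformer(
    num_attention_heads=2, attention_head_dim=8, num_layers=2, embedding_dim=32, num_embeddings=7
)

# A single processor instance is applied to **all** Attention layers...
prior.set_attn_processor(AttnProcessor())

# ...or pass a dict keyed by processor path to target each layer individually.
prior.set_attn_processor({name: AttnProcessor() for name in prior.attn_processors.keys()})
```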
...@@ -178,10 +177,12 @@ class PriorTransformer(ModelMixin, ConfigMixin):
return_dict: bool = True,
):
"""
The [`PriorTransformer`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The currently predicted image embeddings.
timestep (`torch.LongTensor`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
...@@ -190,13 +191,13 @@ class PriorTransformer(ModelMixin, ConfigMixin):
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
If `return_dict` is `True`, a [`~models.prior_transformer.PriorTransformerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]
...
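To make the documented forward contract concrete, the sketch below runs a single denoising step through a deliberately tiny, untrained `PriorTransformer`; all sizes are illustrative and do not correspond to a real checkpoint:

```py
import torch
from diffusers import PriorTransformer

# Tiny, untrained configuration purely for shape-checking the forward pass.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=2,
    embedding_dim=32,
    num_embeddings=7,
)

batch_size, embedding_dim, num_embeddings = 2, 32, 7
hidden_states = torch.randn(batch_size, embedding_dim)        # current image-embedding estimate
proj_embedding = torch.randn(batch_size, embedding_dim)       # embedding the denoising is conditioned on
encoder_hidden_states = torch.randn(batch_size, num_embeddings, embedding_dim)
attention_mask = torch.ones(batch_size, num_embeddings, dtype=torch.bool)

output = prior(
    hidden_states=hidden_states,
    timestep=10,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(output.predicted_image_embedding.shape)  # torch.Size([2, 32])
```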
...@@ -29,10 +29,12 @@ from .modeling_utils import ModelMixin
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
...@@ -40,40 +42,30 @@ class Transformer2DModelOutput(BaseOutput):
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more than
`num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
@register_to_config
...@@ -223,31 +215,34 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for the cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep (`torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
encoder_attention_mask (`torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:
* Mask `(batch, sequence_length)`: True = keep, False = discard.
* Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.
If `ndim == 2`, the input is interpreted as a mask and then converted into a bias consistent with the
format above. This bias is added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
...
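A minimal sketch of the continuous-input path described above, using a tiny, untrained configuration (all sizes are illustrative and chosen only so the forward pass runs quickly):

```py
import torch
from diffusers import Transformer2DModel

# Tiny, untrained continuous-input configuration; `in_channels` marks the input as continuous.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,
    num_layers=1,
    cross_attention_dim=32,
)

batch = 2
hidden_states = torch.randn(batch, 32, 16, 16)       # image-like latents (continuous input)
encoder_hidden_states = torch.randn(batch, 77, 32)   # conditioning embeddings for cross-attention

output = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(output.sample.shape)  # torch.Size([2, 32, 16, 16])
```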