"examples/vscode:/vscode.git/clone" did not exist on "0cc48011aba1de3fc926019489a831bed81373a1"
Unverified Commit 174dcd69 authored by Steven Liu's avatar Steven Liu Committed by GitHub
Browse files

[docs] Model API (#3562)

* add modelmixin and unets

* remove old model page

* minor fixes

* fix unet2dcondition

* add vqmodel and autoencoderkl

* add rest of models

* fix autoencoderkl path

* fix toctree

* fix toctree again

* apply feedback

* apply feedback

* fix copies

* fix controlnet copy

* fix copies
parent cdf2ae8a
......@@ -132,8 +132,6 @@
title: Conceptual Guides
- sections:
- sections:
- local: api/models
title: Models
- local: api/attnprocessor
title: Attention Processor
- local: api/diffusion_pipeline
......@@ -151,6 +149,30 @@
- local: api/image_processor
title: VAE Image Processor
title: Main Classes
- sections:
- local: api/models/overview
title: Overview
- local: api/models/unet
title: UNet1DModel
- local: api/models/unet2d
title: UNet2DModel
- local: api/models/unet2d-cond
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/vq
title: VQModel
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/transformer2d
title: Transformer2D
- local: api/models/transformer_temporal
title: Transformer Temporal
- local: api/models/prior_transformer
title: Prior Transformer
- local: api/models/controlnet
title: ControlNet
title: Models
- sections:
- local: api/pipelines/overview
title: Overview
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Models
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample, by modeling the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\).
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
## ModelMixin
[[autodoc]] ModelMixin
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
## UNet2DModel
[[autodoc]] UNet2DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
## UNet3DConditionModel
[[autodoc]] UNet3DConditionModel
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
## VQEncoderOutput
[[autodoc]] models.vq_model.VQEncoderOutput
## VQModel
[[autodoc]] VQModel
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
## AutoencoderKL
[[autodoc]] AutoencoderKL
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
## TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
## Transformer2DModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
## PriorTransformer
[[autodoc]] models.prior_transformer.PriorTransformer
## PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
## ControlNetOutput
[[autodoc]] models.controlnet.ControlNetOutput
## ControlNetModel
[[autodoc]] ControlNetModel
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
# AutoencoderKL
The variational autoencoder (VAE) model with KL loss was introduced in [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114v11) by Diederik P. Kingma and Max Welling. The model is used in 🤗 Diffusers to encode images into latents and to decode latent representations into images.
The abstract from the paper is:
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
## AutoencoderKL
[[autodoc]] AutoencoderKL
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
\ No newline at end of file
# ControlNet
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
## ControlNetModel
[[autodoc]] ControlNetModel
## ControlNetOutput
[[autodoc]] models.controlnet.ControlNetOutput
## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
\ No newline at end of file
# Models
🤗 Diffusers provides pretrained models for popular algorithms and modules to create custom diffusion systems. The primary function of models is to denoise an input sample as modeled by the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\).
All models are built from the base [`ModelMixin`] class which is a [`torch.nn.module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) providing basic functionality for saving and loading models, locally and from the Hugging Face Hub.
## ModelMixin
[[autodoc]] ModelMixin
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
\ No newline at end of file
# Prior Transformer
The Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents
](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process.
The abstract from the paper is:
*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*
## PriorTransformer
[[autodoc]] PriorTransformer
## PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
\ No newline at end of file
# Transformer2D
A Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
When the input is **continuous**:
1. Project the input and reshape it to `(batch_size, sequence_length, feature_dimension)`.
2. Apply the Transformer blocks in the standard way.
3. Reshape to image.
When the input is **discrete**:
<Tip>
It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.
</Tip>
1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings.
2. Apply the Transformer blocks in the standard way.
3. Predict classes of unnoised image.
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
# Transformer Temporal
A Transformer model for video-like data.
## TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
## TransformerTemporalModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
\ No newline at end of file
# UNet1DModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 1D UNet model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
\ No newline at end of file
# UNet2DConditionModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
\ No newline at end of file
# UNet2DModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet2DModel
[[autodoc]] UNet2DModel
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
\ No newline at end of file
# UNet3DConditionModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 3D UNet conditional model.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
## UNet3DConditionModel
[[autodoc]] UNet3DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
\ No newline at end of file
# VQModel
The VQ-VAE model was introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. The model is used in 🤗 Diffusers to decode latent representations into images. Unlike [`AutoencoderKL`], the [`VQModel`] works in a quantized latent space.
The abstract from the paper is:
*Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEncoder (VQ-VAE), differs from VAEs in two key ways: the encoder network outputs discrete, rather than continuous, codes; and the prior is learnt rather than static. In order to learn a discrete latent representation, we incorporate ideas from vector quantisation (VQ). Using the VQ method allows the model to circumvent issues of "posterior collapse" -- where the latents are ignored when they are paired with a powerful autoregressive decoder -- typically observed in the VAE framework. Pairing these representations with an autoregressive prior, the model can generate high quality images, videos, and speech as well as doing high quality speaker conversion and unsupervised learning of phonemes, providing further evidence of the utility of the learnt representations.*
## VQModel
[[autodoc]] VQModel
## VQEncoderOutput
[[autodoc]] models.vq_model.VQEncoderOutput
\ No newline at end of file
......@@ -39,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput):
class AutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
......@@ -131,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
the processing of larger images.
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
computing decoding in one step.
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
......@@ -152,7 +152,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
......@@ -185,11 +185,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
......@@ -274,14 +278,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
......@@ -319,17 +330,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""Decode a batch of images using a tiled decoder.
r"""
Decode a batch of images using a tiled decoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
`True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
......
......@@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
"""
The output of [`ControlNetModel`].
Args:
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`):
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
......@@ -87,6 +101,58 @@ class ControlNetConditioningEmbedding(nn.Module):
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True
@register_to_config
......@@ -283,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
load_weights_from_unet: bool = True,
):
r"""
Instantiate Controlnet class from UNet2DConditionModel.
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
UNet model which weights are copied to the ControlNet. Note that all configuration options are also
copied where applicable.
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
......@@ -357,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
......@@ -397,15 +467,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
`num_attention_heads` must be a multiple of `slice_size`.
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
......@@ -476,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin):
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
......
......@@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import (
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
The output of [`FlaxControlNetModel`].
Args:
down_block_res_samples (`jnp.ndarray`):
mid_block_res_sample (`jnp.ndarray`):
"""
down_block_res_samples: jnp.ndarray
mid_block_res_sample: jnp.ndarray
......@@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
A ControlNet model.
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
implemented for all models (such as downloading or saving).
This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
general usage and behavior.
Finally, this model supports inherent JAX features such as:
Inherent JAX features such as the following are supported:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
......@@ -120,9 +124,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
The size of the input sample.
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
......@@ -139,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
The channel order of conditional image. Will convert it to `rgb` if it's `bgr`
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in conditioning_embedding layer
The tuple of output channel for each block in the `conditioning_embedding` layer.
"""
sample_size: int = 32
in_channels: int = 4
......
......@@ -44,10 +44,12 @@ logger = logging.get_logger(__name__)
class FlaxModelMixin:
r"""
Base class for all flax models.
Base class for all Flax models.
[`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models.
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
saving models.
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
......@@ -89,15 +91,15 @@ class FlaxModelMixin:
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip.
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
......@@ -132,8 +134,8 @@ class FlaxModelMixin:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
......@@ -155,15 +157,15 @@ class FlaxModelMixin:
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
`params` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
......@@ -201,71 +203,68 @@ class FlaxModelMixin:
**kwargs,
):
r"""
Instantiate a pretrained flax model from a pre-trained model configuration.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Instantiate a pretrained Flax model from a pretrained model configuration.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids are namespaced under a user or organization name, like
`runwayml/stable-diffusion-v1-5`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxModelMixin.save_pretrained`].
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
specified, all the computation will be performed with the given `dtype`.
<Tip>
This only specifies the dtype of the *computation* and does not influence the dtype of model
parameters.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
[`~FlaxModelMixin.to_bf16`].
</Tip>
If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
[`~ModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
All remaining positional arguments are passed to the underlying model's `__init__` method.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch checkpoint save file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
Can be used to update the configuration object (after it is loaded) and initiate the model (for
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
a configuration attribute will be used to override said attribute with the supplied `kwargs`
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
underlying model's `__init__` function.
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
model's `__init__` method (we assume all relevant updates to the configuration have already been
done).
- If a configuration is not provided, `kwargs` are first passed to the configuration class
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
model's `__init__` function.
Examples:
......@@ -276,7 +275,16 @@ class FlaxModelMixin:
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```"""
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
......@@ -491,18 +499,18 @@ class FlaxModelMixin:
is_main_process: bool = True,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~FlaxModelMixin.from_pretrained`]` class method
Save a model and its configuration file to a directory so that it can be reloaded using the
[`~FlaxModelMixin.from_pretrained`] class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
......
......@@ -154,11 +154,10 @@ class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models.
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
saving models.
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~models.ModelMixin.save_pretrained`].
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
......@@ -190,18 +189,13 @@ class ModelMixin(torch.nn.Module):
def is_gradient_checkpointing(self) -> bool:
"""
Whether gradient checkpointing is activated for this model or not.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
"""
if not self._supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
......@@ -209,10 +203,8 @@ class ModelMixin(torch.nn.Module):
def disable_gradient_checkpointing(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
"""
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
......@@ -236,13 +228,17 @@ class ModelMixin(torch.nn.Module):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
inference. Speed up during training is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
<Tip warning={true}>
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
precedent.
</Tip>
Parameters:
attention_op (`Callable`, *optional*):
......@@ -268,7 +264,7 @@ class ModelMixin(torch.nn.Module):
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
"""
self.set_use_memory_efficient_attention_xformers(False)
......@@ -281,24 +277,24 @@ class ModelMixin(torch.nn.Module):
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~models.ModelMixin.from_pretrained`]` class method.
Save a model and its configuration file to a directory so that it can be reloaded using the
[`~models.ModelMixin.from_pretrained`] class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
......@@ -335,107 +331,108 @@ class ModelMixin(torch.nn.Module):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
Instantiate a pretrained PyTorch model from a pretrained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
train the model, set it back in training mode with `model.train()`.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`~ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
from_flax (`bool`, *optional*, defaults to `False`):
Load the model weights from a Flax checkpoint save file.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
each GPU and the available CPU RAM if unset.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
The path to offload weights if `device_map` contains the value `"disk"`.
offload_state_dict (`bool`, *optional*):
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
when there is some disk offload.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
`huggingface-cli login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
</Tip>
<Tip>
Example:
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
```py
from diffusers import UNet2DConditionModel
</Tip>
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
......@@ -852,17 +849,27 @@ class ModelMixin(torch.nn.Module):
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get number of (optionally, trainable or non-embeddings) parameters in the module.
Get number of (trainable or non-embedding) parameters in the module.
Args:
only_trainable (`bool`, *optional*, defaults to `False`):
Whether or not to return only the number of trainable parameters
Whether or not to return only the number of trainable parameters.
exclude_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to return only the number of non-embeddings parameters
Whether or not to return only the number of non-embedding parameters.
Returns:
`int`: The number of parameters.
Example:
```py
from diffusers import UNet2DConditionModel
model_id = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.num_parameters(only_trainable=True)
859520964
```
"""
if exclude_embeddings:
......
......@@ -16,6 +16,8 @@ from .modeling_utils import ModelMixin
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
The output of [`PriorTransformer`].
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
......@@ -26,27 +28,20 @@ class PriorTransformerOutput(BaseOutput):
class PriorTransformer(ModelMixin, ConfigMixin):
"""
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
transformer predicts the image embeddings through a denoising diffusion process.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
For more details, see the original paper: https://arxiv.org/abs/2204.06125
A Prior Transformer model.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
length of the prompt after it has been tokenized.
embedding_dim (`int`, *optional*, defaults to 768):
The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the
length of the prompt after it has been tokenized).
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
@register_to_config
......@@ -133,11 +128,15 @@ class PriorTransformer(ModelMixin, ConfigMixin):
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
......@@ -178,10 +177,12 @@ class PriorTransformer(ModelMixin, ConfigMixin):
return_dict: bool = True,
):
"""
The [`PriorTransformer`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
x_t, the currently predicted image embeddings.
timestep (`torch.long`):
The currently predicted image embeddings.
timestep (`torch.LongTensor`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
......@@ -190,13 +191,13 @@ class PriorTransformer(ModelMixin, ConfigMixin):
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]
......
......@@ -29,10 +29,12 @@ from .modeling_utils import ModelMixin
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
for the unnoised latent pixels.
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
......@@ -40,40 +42,30 @@ class Transformer2DModelOutput(BaseOutput):
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
@register_to_config
......@@ -223,31 +215,34 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
encoder_attention_mask ( `torch.Tensor`, *optional* ).
Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
= keep, -10000 = discard.
If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment