"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ad3b09f189995fa6b84a57df1a4569372d9e1147"
Unverified Commit 63f767ef authored by Suraj Patil, committed by GitHub

Add SVD (#5895)



* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in AlphaBlender and add time activation in res block

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg

* remove more unused args

* up

* up

* check for None

* update vae

* update up/mid blocks for decoder

* begin pipeline

* adapt scheduler

* add guidance scalings

* fix norm eps in temporal transformers

* add temporal autoencoder

* make pipeline run

* fix frame decoding

* decode in float32

* decode n frames at a time

* pass decoding_t to decode_latents

* fix decode_latents

* vae encode/decode in fp32

* fix dtype in TransformerSpatioTemporalModel

* type image_latents same as image_embeddings

* allow using different eps in temporal block for video decoder

* fix default values in vae

* pass num frames in decode

* switch spatial to temporal for mixing in VAE

* fix num frames during split decoding

* cast alpha to sample dtype

* fix attention in MidBlockTemporalDecoder

* fix typo

* fix guidance_scales dtype

* fix missing activation in TemporalDecoder

* skip_post_quant_conv

* add vae conversion

* style

* take guidance scale as input

* up

* allow passing PIL to export_video

* accept fps as arg

* add pipeline and vae in init

* remove hack

* use AutoencoderKLTemporalDecoder

* don't scale image latents

* add unet tests

* clean up unet

* clean TransformerSpatioTemporalModel

* add slow svd test

* clean up

* make temb optional in Decoder mid block

* fix norm eps in TransformerSpatioTemporalModel

* clean up temp decoder

* clean up

* clean up

* use c_noise values for timesteps

* use math for log

* update

* fix copies

* doc

* upcast vae

* update forward pass for gradient checkpointing

* make added_time_ids a tensor

* up

* fix upcasting

* remove post quant conv

* add _resize_with_antialiasing

* fix _compute_padding

* cleanup model

* more cleanup

* more cleanup

* more cleanup

* remove freeu

* remove attn slice

* small clean

* up

* up

* remove extra step kwargs

* remove eta

* remove dropout

* remove callback

* remove merge factor args

* clean

* clean up

* move to dedicated folder

* remove attention_head_dim

* docstr and small fix

* update unet doc strings

* rename decoding_t

* correct linting

* store c_skip and c_out

* cleanup

* clean TemporalResnetBlock

* more cleanup

* clean up vae

* clean up

* begin doc

* more cleanup

* up

* up

* doc

* Improve

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* Apply suggestions from code review

* Default chunk size to None

* add example

* Better

* Apply suggestions from code review

* update doc

* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* Get torch compile working

* up

* rename

* fix doc

* add chunking

* torch compile

* torch compile

* add modeling outputs

* torch compile

* Improve chunking

* Apply suggestions from code review

* Update docs/source/en/using-diffusers/svd.md

* Close diff tag

* remove slicing

* resnet docstr

* add docstr in resnet

* rename

* Apply suggestions from code review

* update tests

* Fix output type latents

* fix more

* fix more

* Update docs/source/en/using-diffusers/svd.md

* fix more

* add pipeline tests

* remove unused arg

* clean up

* make sure get_scaling receives tensors

* fix euler scheduler

* fix get_scalings

* simply euler for now

* remove old test file

* use randn_tensor to create noise

* fix device for rand tensor

* increase expected_max_difference

* fix test_inference_batch_single_identical

* actually fix test_inference_batch_single_identical

* disable test_save_load_float16

* skip test_float16_inference

* skip test_inference_batch_single_identical

* fix test_xformers_attention_forwardGenerator_pass

* Apply suggestions from code review

* update StableVideoDiffusionPipelineSlowTests

* update image

* add diffusers example

* fix more

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
parent d1b2a1a9
...@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
- Models can be optimized for performance when it doesn’t demand major code changes, keep backward compatibility, and give significant memory or compute gain.
...
...@@ -94,6 +94,8 @@
title: Latent Consistency Model-LoRA
- local: using-diffusers/inference_with_lcm
title: Latent Consistency Model
- local: using-diffusers/svd
title: Stable Video Diffusion
title: Specific pipeline examples
- sections:
- local: training/overview
...
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Video Diffusion
[[open-in-colab]]
[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high-resolution (576x1024), 2-4 second videos conditioned on an input image.
This guide will show you how to use SVD to generate short videos from images.
Before you begin, make sure you have the following libraries installed:
```py
!pip install -q -U diffusers transformers accelerate
```
## Image to Video Generation
There are two variants of SVD: [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The SVD checkpoint is trained to generate 14 frames and the SVD-XT checkpoint is further
finetuned to generate 25 frames.
We will use the `svd-xt` checkpoint for this guide.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4?download=true" type="video/mp4">
</video>
<Tip>
Since generating videos is more memory intensive, we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This reduces memory usage. It's recommended to tweak this value based on your GPU memory.
Setting `decode_chunk_size=1` will decode one frame at a time and use the least amount of memory, but the video might have some flickering.
Additionally, we use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
</Tip>
### Torch.compile
You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
```diff
- pipe.enable_model_cpu_offload()
+ pipe.to("cuda")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
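For reference, here is a sketch of the full compiled setup, reusing the checkpoint and conditioning image from the example above (the first call is slow while the UNet compiles):

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
# Keep the whole pipeline on the GPU instead of offloading so the compiled graph isn't broken up
pipe.to("cuda")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
# The first call triggers compilation; subsequent calls are roughly 20-25% faster
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```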
### Low-memory
Video generation is very memory intensive because we essentially have to generate all `num_frames` at once, much like text-to-image generation with a very high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed for a lower memory requirement:
- enable model offloading: Each component of the pipeline is offloaded to the CPU once it's not needed anymore.
- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single, huge feed-forward batch size.
- reduce `decode_chunk_size`: The VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to a small slowdown, this also leads to a slight deterioration in video quality.
You can enable them as follows:
```diff
-pipe.enable_model_cpu_offload()
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
+pipe.enable_model_cpu_offload()
+pipe.unet.enable_forward_chunking()
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
```
Including all these tricks should lower the memory requirement to less than 8 GB of VRAM.
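Putting these options together, a full low-memory version of the earlier example could look like this (actual savings depend on your GPU and the number of frames):

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
# Offload each sub-model to the CPU when it is not in use
pipe.enable_model_cpu_offload()
# Run the feed-forward layers of the temporal transformer blocks in chunks
pipe.unet.enable_forward_chunking()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
# Decode only 2 frames at a time to keep VAE memory low
frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```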
### Micro-conditioning
Along with the conditioning image, Stable Video Diffusion also accepts micro-conditioning, which allows more control over the generated video.
It accepts the following arguments:
- `fps`: The frames per second of the generated video.
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id increases the motion of the generated video.
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the value, the less the video will resemble the conditioning image. Increasing this value also increases the motion of the generated video.
Here is an example of using micro-conditioning to generate a video with more motion.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4?download=true" type="video/mp4">
</video>
...@@ -76,6 +76,7 @@ else:
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
...@@ -92,6 +93,7 @@ else:
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
...@@ -277,6 +279,7 @@ else:
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"TextToVideoZeroSDXLPipeline",
...@@ -447,6 +450,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
...@@ -463,6 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
...@@ -627,6 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
TextToVideoZeroSDXLPipeline,
...
...@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
...@@ -23,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
...@@ -38,6 +44,7 @@ if is_torch_available():
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
...@@ -51,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
...@@ -66,6 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():
...
...@@ -25,6 +25,31 @@ from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
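Chunking is safe here because the feed-forward network acts on each position independently, so splitting along the chunked dimension and concatenating gives the same result as a single pass, only with lower peak memory. A minimal, self-contained sketch of that property (using a plain `nn.Linear` as a stand-in for the diffusers `FeedForward` module):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ff = nn.Linear(8, 8)  # stand-in position-wise feed-forward
hidden_states = torch.randn(2, 6, 8)  # (batch, seq_len, channels), chunk along dim=1

chunk_dim, chunk_size = 1, 2
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
chunked = torch.cat(
    [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)

# The chunked pass matches the full pass; only peak memory differs.
assert torch.allclose(chunked, ff(hidden_states), atol=1e-6)
```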
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
...@@ -194,7 +219,12 @@ class BasicTransformerBlock(nn.Module):
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
...@@ -208,7 +238,7 @@ class BasicTransformerBlock(nn.Module):
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
...@@ -311,18 +341,8 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
...@@ -339,6 +359,137 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video-like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
...
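`TemporalBasicTransformerBlock` above attends over frames rather than spatial positions by folding the spatial dimension into the batch; the reshape at the end of its `forward` exactly inverts the one at the start. A small sketch of that round trip in plain PyTorch (shapes are illustrative):

```python
import torch

batch_size, num_frames, seq_length, channels = 2, 4, 6, 8
x = torch.randn(batch_size * num_frames, seq_length, channels)  # input layout of the block

# (batch * frames, seq, C) -> (batch * seq, frames, C): temporal attention mixes frames per spatial position
t = x.reshape(batch_size, num_frames, seq_length, channels)
t = t.permute(0, 2, 1, 3).reshape(batch_size * seq_length, num_frames, channels)

# Inverse reshape used at the end of the block's forward
y = t.reshape(batch_size, seq_length, num_frames, channels)
y = y.permute(0, 2, 1, 3).reshape(batch_size * num_frames, seq_length, channels)

assert torch.equal(x, y)  # the round trip restores the original layout
```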
...@@ -18,7 +18,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
...
...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
...@@ -19,7 +18,6 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
...@@ -28,24 +26,11 @@ from .attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
...
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = MidBlockTemporalDecoder(
num_layers=self.layers_per_block,
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
attention_head_dim=block_out_channels[-1],
)
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlockTemporalDecoder(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = torch.nn.Conv2d(
in_channels=block_out_channels[0],
out_channels=out_channels,
kernel_size=3,
padding=1,
)
conv_out_kernel_size = (3, 1, 1)
padding = [int(k // 2) for k in conv_out_kernel_size]
self.time_conv_out = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=conv_out_kernel_size,
padding=padding,
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
num_frames: int = 1,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
)
else:
# middle
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, image_only_indicator=image_only_indicator)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
batch_frames, channels, height, width = sample.shape
batch_size = batch_frames // num_frames
sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
double_z=True,
)
# pass init params to Decoder
self.decoder = TemporalDecoder(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
num_frames: int,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
batch_size = z.shape[0] // num_frames
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, num_frames=num_frames).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
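A minimal smoke-test sketch of the class above, assuming it is importable from `diffusers` as added in this PR and using its small default configuration (random weights, toy 32x32 frames); the pretrained SVD VAE uses a much larger configuration:

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder

# Toy-sized default configuration (random weights); the pretrained SVD VAE is much larger.
vae = AutoencoderKLTemporalDecoder()
vae.eval()

num_frames = 2
frames = torch.randn(num_frames, 3, 32, 32)  # frames are stacked along the batch dimension

with torch.no_grad():
    # Encoding is frame-wise; decoding takes `num_frames` so the temporal layers
    # can regroup the flattened (batch * frames) dimension.
    latents = vae.encode(frames).latent_dist.sample()
    video = vae.decode(latents, num_frames=num_frames).sample

# With the default single-block config there is no spatial downsampling,
# so the expected shapes are (2, 4, 32, 32) and (2, 3, 32, 32).
print(latents.shape, video.shape)
```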
from dataclasses import dataclass
from ..utils import BaseOutput
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
...@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
...@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
...@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
...@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
...@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
...@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
...@@ -910,7 +942,10 @@ def upsample_2d(
def downsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
...@@ -946,13 +981,20 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
...@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
def __init__(
self,
in_dim: int,
out_dim: Optional[int] = None,
dropout: float = 0.0,
norm_num_groups: int = 32,
):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
...@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
...@@ -1058,3 +1108,261 @@ class TemporalConvLayer(nn.Module):
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
class TemporalResnetBlock(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
temb_channels: int = 512,
eps: float = 1e-6,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
kernel_size = (3, 1, 1)
padding = [k // 2 for k in kernel_size]
self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv3d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(0.0)
self.conv2 = nn.Conv3d(
out_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
self.nonlinearity = get_activation("silu")
self.use_in_shortcut = self.in_channels != out_channels
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = nn.Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
r"""
A SpatioTemporal Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
The merge strategy to use for the temporal mixing.
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
If `True`, switch the spatial and temporal mixing.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
temb_channels: int = 512,
eps: float = 1e-6,
temporal_eps: Optional[float] = None,
merge_factor: float = 0.5,
merge_strategy="learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=eps,
)
self.temporal_res_block = TemporalResnetBlock(
in_channels=out_channels if out_channels is not None else in_channels,
out_channels=out_channels if out_channels is not None else in_channels,
temb_channels=temb_channels,
eps=temporal_eps if temporal_eps is not None else eps,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
):
num_frames = image_only_indicator.shape[-1]
hidden_states = self.spatial_res_block(hidden_states, temb)
batch_frames, channels, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
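# (batch * frames, channels, height, width) -> (batch, channels, frames, height, width)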
hidden_states_mix = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
hidden_states = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
if temb is not None:
temb = temb.reshape(batch_size, num_frames, -1)
hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
image_only_indicator=image_only_indicator,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return hidden_states
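A minimal usage sketch (illustrative only; the channel and frame counts are arbitrary) showing that the block consumes flattened (batch * frames) spatial features together with an `image_only_indicator` of shape (batch, frames):
import torch

num_frames = 8
block = SpatioTemporalResBlock(in_channels=64, temb_channels=None)
sample = torch.randn(2 * num_frames, 64, 32, 32)   # (batch * frames, channels, height, width)
image_only_indicator = torch.zeros(2, num_frames)  # 0 = video frame, 1 = still image
out = block(sample, temb=None, image_only_indicator=image_only_indicator)
assert out.shape == sample.shape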
class AlphaBlender(nn.Module):
r"""
A module to blend spatial and temporal features.
Parameters:
alpha (`float`): The initial value of the blending factor.
merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
The merge strategy to use for the temporal mixing.
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
If `True`, switch the spatial and temporal mixing.
"""
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.merge_strategy = merge_strategy
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
if merge_strategy not in self.strategies:
raise ValueError(f"merge_strategy needs to be in {self.strategies}")
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
if image_only_indicator is None:
raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.sigmoid(self.mix_factor)[..., None],
)
# (batch, channel, frames, height, width)
if ndims == 5:
alpha = alpha[:, None, :, None, None]
# (batch*frames, height*width, channels)
elif ndims == 3:
alpha = alpha.reshape(-1)[:, None, None]
else:
raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
alpha = alpha.to(x_spatial.dtype)
if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha
x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x
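In other words, the blender computes `x = alpha * x_spatial + (1 - alpha) * x_temporal`, where `alpha` is either a fixed constant or the sigmoid of a learned scalar, and with `learned_with_images` it is forced to 1 for samples flagged as still images. A small sketch with illustrative shapes:
import torch

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
x_spatial = torch.randn(2, 64, 8, 16, 16)   # (batch, channels, frames, height, width)
x_temporal = torch.randn(2, 64, 8, 16, 16)
image_only_indicator = torch.zeros(2, 8)    # all entries are video frames
mixed = blender(x_spatial, x_temporal, image_only_indicator=image_only_indicator)
assert mixed.shape == x_spatial.shape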
@@ -19,8 +19,10 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender
@dataclass
@@ -195,3 +197,183 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return (output,)
return TransformerTemporalModelOutput(sample=output)
class TransformerSpatioTemporalModel(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
out_channels (`int`, *optional*):
The number of channels in the output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size * num frames, channel, height, width)`):
Input hidden_states.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size * num frames, sequence length, embedding dim)`, *optional*):
Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
self-attention.
image_only_indicator (`torch.Tensor` of shape `(batch size, num frames)`, *optional*):
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames. The number of frames per sample is inferred
from the last dimension of this tensor.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, _, height, width = hidden_states.shape
num_frames = image_only_indicator.shape[-1]
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
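# the temporal blocks cross-attend to the conditioning of the first frame of each clip,
# broadcast over all spatial positions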
time_context_first_timestep = time_context[None, :].reshape(
batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]
time_context = time_context_first_timestep[None, :].broadcast_to(
height * width, batch_size, 1, time_context.shape[-1]
)
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but the time embedding might actually be running in fp16, so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)
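A hedged end-to-end sketch of the forward pass (not part of the diff; the head count, resolution, and conditioning dimension are illustrative). The spatial features are flattened to (batch * frames, channels, height, width) and the conditioning is a short sequence of embeddings per frame:
import torch

model = TransformerSpatioTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    cross_attention_dim=1024,
)
batch_size, num_frames = 1, 4
hidden_states = torch.randn(batch_size * num_frames, 320, 16, 16)
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)  # e.g. per-frame image embeddings
image_only_indicator = torch.zeros(batch_size, num_frames)
sample = model(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
    return_dict=False,
)[0]
assert sample.shape == hidden_states.shape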
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -274,7 +279,9 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
@@ -292,14 +299,20 @@ class Decoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
@@ -540,7 +553,10 @@ class MaskConditionDecoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
@@ -548,7 +564,10 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
@@ -558,7 +577,10 @@ class MaskConditionDecoder(nn.Module):
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,7 +595,9 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
@@ -754,7 +778,10 @@ class DiagonalGaussianDistribution(object):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
@@ -764,7 +791,10 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -779,7 +809,10 @@ class DiagonalGaussianDistribution(object):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
@@ -820,7 +853,16 @@ class EncoderTiny(nn.Module):
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -899,7 +941,15 @@ class DecoderTiny(nn.Module):
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(
nn.Conv2d(
num_channels,
conv_out_channel,
kernel_size=3,
padding=1,
bias=is_final_block,
)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
@@ -17,7 +17,12 @@ from ..utils import (
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {
"controlnet": [],
"latent_diffusion": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
}
try:
if not is_torch_available():
@@ -39,7 +44,11 @@ else:
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pipeline_utils"] = [
"AudioPipelineOutput",
"DiffusionPipeline",
"ImagePipelineOutput",
]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
@@ -61,7 +70,10 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = [
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
]
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -110,7 +122,10 @@ else:
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["kandinsky3"] = [
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
]
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
@@ -150,6 +165,7 @@ else:
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
@@ -158,7 +174,10 @@ else:
"StableDiffusionXLPipeline",
]
)
_import_structure["t2i_adapter"] = [
"StableDiffusionAdapterPipeline",
"StableDiffusionXLAdapterPipeline",
]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
@@ -216,7 +235,9 @@ try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -259,7 +280,10 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = [
"MidiProcessor",
"SpectrogramDiffusionPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -269,7 +293,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .auto_pipeline import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
)
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
@@ -277,7 +305,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
ImagePipelineOutput,
)
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
@@ -300,7 +332,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
)
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
@@ -344,7 +380,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
@@ -383,7 +422,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (
StableDiffusionAdapterPipeline,
StableDiffusionXLAdapterPipeline,
)
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
@@ -473,7 +516,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import (
MidiProcessor,
SpectrogramDiffusionPipeline,
)
else:
import sys
@@ -55,7 +55,9 @@ try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
StableDiffusionImageVariationPipeline,
)
_dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else:
@@ -90,7 +92,9 @@ try:
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -137,18 +141,32 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from .pipeline_stable_diffusion_attend_and_excite import (
StableDiffusionAttendAndExcitePipeline,
)
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
from .pipeline_stable_diffusion_gligen_text_image import (
StableDiffusionGLIGENTextImagePipeline,
)
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import (
StableDiffusionInpaintPipelineLegacy,
)
from .pipeline_stable_diffusion_instruct_pix2pix import (
StableDiffusionInstructPix2PixPipeline,
)
from .pipeline_stable_diffusion_latent_upscale import (
StableDiffusionLatentUpscalePipeline,
)
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .pipeline_stable_diffusion_model_editing import (
StableDiffusionModelEditingPipeline,
)
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .pipeline_stable_diffusion_paradigms import (
StableDiffusionParadigmsPipeline,
)
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
@@ -160,9 +178,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
StableDiffusionImageVariationPipeline,
)
else:
from .pipeline_stable_diffusion_image_variation import (
StableDiffusionImageVariationPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
@@ -174,9 +196,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
)
else:
from .pipeline_stable_diffusion_depth2img import (
StableDiffusionDepth2ImgPipeline,
)
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
from .pipeline_stable_diffusion_pix2pix_zero import (
StableDiffusionPix2PixZeroPipeline,
)
try:
if not (
@@ -189,7 +215,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .pipeline_stable_diffusion_k_diffusion import (
StableDiffusionKDiffusionPipeline,
)
try:
if not (is_transformers_available() and is_onnx_available()):
@@ -197,11 +225,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import *
else:
from .pipeline_onnx_stable_diffusion import (
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
from .pipeline_onnx_stable_diffusion_img2img import (
OnnxStableDiffusionImg2ImgPipeline,
)
from .pipeline_onnx_stable_diffusion_inpaint import (
OnnxStableDiffusionInpaintPipeline,
)
from .pipeline_onnx_stable_diffusion_inpaint_legacy import (
OnnxStableDiffusionInpaintPipelineLegacy,
)
from .pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
try:
if not (is_transformers_available() and is_flax_available()):
@@ -210,8 +249,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_flax_objects import *
else:
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .pipeline_flax_stable_diffusion_img2img import (
FlaxStableDiffusionImg2ImgPipeline,
)
from .pipeline_flax_stable_diffusion_inpaint import (
FlaxStableDiffusionInpaintPipeline,
)
from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
BaseOutput,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_stable_video_diffusion": [
"StableVideoDiffusionPipeline",
"StableVideoDiffusionPipelineOutput",
],
}
)
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_stable_video_diffusion import (
StableVideoDiffusionPipeline,
StableVideoDiffusionPipelineOutput,
)
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
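With the pipeline registered in the lazy import structure above (and re-exported from the top-level package in the pipelines `__init__` changes earlier in this diff), it can be loaded in the usual way; a hedged sketch, where the checkpoint id is only an example:
import torch
from diffusers import StableVideoDiffusionPipeline

# example checkpoint id; any repository exposing this pipeline class should work
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16
)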