Unverified Commit 0d1d267b authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[core] Allegro T2V (#9736)



* update

* refactor transformer part 1

* refactor part 2

* refactor part 3

* make style

* refactor part 4; modeling tests

* make style

* refactor part 5

* refactor part 6

* gradient checkpointing

* pipeline tests (broken atm)

* update

* add coauthor
Co-Authored-By: default avatarHuan Yang <hyang@fastmail.com>

* refactor part 7

* add docs

* make style

* add coauthor
Co-Authored-By: default avatarYiYi Xu <yixu310@gmail.com>

* make fix-copies

* undo unrelated change

* revert changes to embeddings, normalization, transformer

* refactor part 8

* make style

* refactor part 9

* make style

* fix

* apply suggestions from review

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update example

* remove attention mask for self-attention

* update

* copied from

* update

* update

---------
Co-authored-by: default avatarHuan Yang <hyang@fastmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent c5376c56
...@@ -252,6 +252,8 @@ ...@@ -252,6 +252,8 @@
title: SparseControlNetModel title: SparseControlNetModel
title: ControlNets title: ControlNets
- sections: - sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d - local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d - local: api/models/cogvideox_transformer3d
...@@ -300,6 +302,8 @@ ...@@ -300,6 +302,8 @@
- sections: - sections:
- local: api/models/autoencoderkl - local: api/models/autoencoderkl
title: AutoencoderKL title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox - local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX title: AutoencoderKLCogVideoX
- local: api/models/asymmetricautoencoderkl - local: api/models/asymmetricautoencoderkl
...@@ -318,6 +322,8 @@ ...@@ -318,6 +322,8 @@
sections: sections:
- local: api/pipelines/overview - local: api/pipelines/overview
title: Overview title: Overview
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/amused - local: api/pipelines/amused
title: aMUSEd title: aMUSEd
- local: api/pipelines/animatediff - local: api/pipelines/animatediff
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AllegroTransformer3DModel
A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AllegroTransformer3DModel
vae = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## AllegroTransformer3DModel
[[autodoc]] AllegroTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLAllegro
The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLAllegro
vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLAllegro
[[autodoc]] AutoencoderKLAllegro
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Allegro
[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang.
The abstract from the paper is:
*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## AllegroPipeline
[[autodoc]] AllegroPipeline
- all
- __call__
## AllegroPipelineOutput
[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
...@@ -77,9 +77,11 @@ except OptionalDependencyNotAvailable: ...@@ -77,9 +77,11 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["models"].extend( _import_structure["models"].extend(
[ [
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL", "AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel", "AuraFlowTransformer2DModel",
"AutoencoderKL", "AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX", "AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder", "AutoencoderKLTemporalDecoder",
"AutoencoderOobleck", "AutoencoderOobleck",
...@@ -237,6 +239,7 @@ except OptionalDependencyNotAvailable: ...@@ -237,6 +239,7 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["pipelines"].extend( _import_structure["pipelines"].extend(
[ [
"AllegroPipeline",
"AltDiffusionImg2ImgPipeline", "AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline", "AltDiffusionPipeline",
"AmusedImg2ImgPipeline", "AmusedImg2ImgPipeline",
...@@ -556,9 +559,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -556,9 +559,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_pt_objects import * # noqa F403 from .utils.dummy_pt_objects import * # noqa F403
else: else:
from .models import ( from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL, AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel, AuraFlowTransformer2DModel,
AutoencoderKL, AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder, AutoencoderKLTemporalDecoder,
AutoencoderOobleck, AutoencoderOobleck,
...@@ -697,6 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -697,6 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else: else:
from .pipelines import ( from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline, AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline, AltDiffusionPipeline,
AmusedImg2ImgPipeline, AmusedImg2ImgPipeline,
......
...@@ -28,6 +28,7 @@ if is_torch_available(): ...@@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
...@@ -54,6 +55,7 @@ if is_torch_available(): ...@@ -54,6 +55,7 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
...@@ -81,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -81,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoders import ( from .autoencoders import (
AsymmetricAutoencoderKL, AsymmetricAutoencoderKL,
AutoencoderKL, AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder, AutoencoderKLTemporalDecoder,
AutoencoderOobleck, AutoencoderOobleck,
...@@ -97,6 +100,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -97,6 +100,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection from .embeddings import ImageProjection
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .transformers import ( from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel, AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel, CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel, CogView3PlusTransformer2DModel,
......
...@@ -1521,6 +1521,100 @@ class FusedJointAttnProcessor2_0: ...@@ -1521,6 +1521,100 @@ class FusedJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
class AllegroAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply RoPE if needed
if image_rotary_emb is not None and not attn.is_cross_attention:
from .embeddings import apply_rotary_emb_allegro
query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AuraFlowAttnProcessor2_0: class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow.""" """Attention processor used typically in processing Aura Flow."""
......
from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_oobleck import AutoencoderOobleck
......
# Copyright 2024 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import Attention, SpatialNorm
from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from ..downsampling import Downsample2D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
class AllegroTemporalConvLayer(nn.Module):
r"""
Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
"""
def __init__(
self,
in_dim: int,
out_dim: Optional[int] = None,
dropout: float = 0.0,
norm_num_groups: int = 32,
up_sample: bool = False,
down_sample: bool = False,
stride: int = 1,
) -> None:
super().__init__()
out_dim = out_dim or in_dim
pad_h = pad_w = int((stride - 1) * 0.5)
pad_t = 0
self.down_sample = down_sample
self.up_sample = up_sample
if down_sample:
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
)
elif up_sample:
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
)
else:
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
)
self.conv3 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
)
self.conv4 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
nn.SiLU(),
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
)
@staticmethod
def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
return hidden_states
def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
if self.down_sample:
identity = hidden_states[:, :, ::2]
elif self.up_sample:
identity = hidden_states.repeat_interleave(2, dim=2)
else:
identity = hidden_states
if self.down_sample or self.up_sample:
hidden_states = self.conv1(hidden_states)
else:
hidden_states = self._pad_temporal_dim(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.up_sample:
hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
hidden_states = self._pad_temporal_dim(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self._pad_temporal_dim(hidden_states)
hidden_states = self.conv3(hidden_states)
hidden_states = self._pad_temporal_dim(hidden_states)
hidden_states = self.conv4(hidden_states)
hidden_states = identity + hidden_states
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
return hidden_states
class AllegroDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
spatial_downsample: bool = True,
temporal_downsample: bool = False,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
temp_convs = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
AllegroTemporalConvLayer(
out_channels,
out_channels,
dropout=0.1,
norm_num_groups=resnet_groups,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
if temporal_downsample:
self.temp_convs_down = AllegroTemporalConvLayer(
out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
)
self.add_temp_downsample = temporal_downsample
if spatial_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = temp_conv(hidden_states, batch_size=batch_size)
if self.add_temp_downsample:
hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return hidden_states
class AllegroUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
temb_channels: Optional[int] = None,
):
super().__init__()
resnets = []
temp_convs = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
AllegroTemporalConvLayer(
out_channels,
out_channels,
dropout=0.1,
norm_num_groups=resnet_groups,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
self.add_temp_upsample = temporal_upsample
if temporal_upsample:
self.temp_conv_up = AllegroTemporalConvLayer(
out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
)
if spatial_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = temp_conv(hidden_states, batch_size=batch_size)
if self.add_temp_upsample:
hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return hidden_states
class AllegroMidBlock3DConv(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
temp_convs = [
AllegroTemporalConvLayer(
in_channels,
in_channels,
dropout=0.1,
norm_num_groups=resnet_groups,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
AllegroTemporalConvLayer(
in_channels,
in_channels,
dropout=0.1,
norm_num_groups=resnet_groups,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
self.attentions = nn.ModuleList(attentions)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.resnets[0](hidden_states, temb=None)
hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb=None)
hidden_states = temp_conv(hidden_states, batch_size=batch_size)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return hidden_states
class AllegroEncoder3D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = (
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)
self.temp_conv_in = nn.Conv3d(
in_channels=block_out_channels[0],
out_channels=block_out_channels[0],
kernel_size=(3, 1, 1),
padding=(1, 0, 0),
)
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if down_block_type == "AllegroDownBlock3D":
down_block = AllegroDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
spatial_downsample=not is_final_block,
temporal_downsample=temporal_downsample_blocks[i],
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
)
else:
raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`")
self.down_blocks.append(down_block)
# mid
self.mid_block = AllegroMidBlock3DConv(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, sample: torch.Tensor) -> torch.Tensor:
batch_size = sample.shape[0]
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_in(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
residual = sample
sample = self.temp_conv_in(sample)
sample = sample + residual
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Down blocks
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
else:
# Down blocks
for down_block in self.down_blocks:
sample = down_block(sample)
# Mid block
sample = self.mid_block(sample)
# Post process
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
residual = sample
sample = self.temp_conv_out(sample)
sample = sample + residual
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_out(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return sample
class AllegroDecoder3D(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = (
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
),
temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = AllegroMidBlock3DConv(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if up_block_type == "AllegroUpBlock3D":
up_block = AllegroUpBlock3D(
num_layers=layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
spatial_upsample=not is_final_block,
temporal_upsample=temporal_upsample_blocks[i],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
else:
raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`")
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, sample: torch.Tensor) -> torch.Tensor:
batch_size = sample.shape[0]
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_in(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
residual = sample
sample = self.temp_conv_in(sample)
sample = sample + residual
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
# Up blocks
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
else:
# Mid block
sample = self.mid_block(sample)
sample = sample.to(upscale_dtype)
# Up blocks
for up_block in self.up_blocks:
sample = up_block(sample)
# Post process
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
residual = sample
sample = self.temp_conv_out(sample)
sample = sample + residual
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
sample = self.conv_out(sample)
sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return sample
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, defaults to `3`):
Number of channels in the input image.
out_channels (int, defaults to `3`):
Number of channels in the output.
down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
Tuple of strings denoting which types of down blocks to use.
up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
Tuple of strings denoting which types of up blocks to use.
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
Tuple of integers denoting number of output channels in each block.
temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
Tuple of booleans denoting which blocks to enable temporal downsampling in.
latent_channels (`int`, defaults to `4`):
Number of channels in latents.
layers_per_block (`int`, defaults to `2`):
Number of resnet or attention or temporal convolution layers per down/up block.
act_fn (`str`, defaults to `"silu"`):
The activation function to use.
norm_num_groups (`int`, defaults to `32`):
Number of groups to use in normalization layers.
temporal_compression_ratio (`int`, defaults to `4`):
Ratio by which temporal dimension of samples are compressed.
sample_size (`int`, defaults to `320`):
Default latent size.
scaling_factor (`float`, defaults to `0.13235`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = (
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
),
up_block_types: Tuple[str, ...] = (
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
latent_channels: int = 4,
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_size: int = 320,
scaling_factor: float = 0.13,
force_upcast: bool = True,
) -> None:
super().__init__()
self.encoder = AllegroEncoder3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
temporal_downsample_blocks=temporal_downsample_blocks,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
self.decoder = AllegroDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
temporal_upsample_blocks=temporal_upsample_blocks,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
# TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
# to use a specific parameter here or in other VAEs.
self.use_slicing = False
self.use_tiling = False
self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
self.tile_overlap_t = 8
self.tile_overlap_h = 120
self.tile_overlap_w = 80
sample_frames = 24
self.kernel = (sample_frames, sample_size, sample_size)
self.stride = (
sample_frames - self.tile_overlap_t,
sample_size - self.tile_overlap_h,
sample_size - self.tile_overlap_w,
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
module.gradient_checkpointing = value
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
if self.use_tiling:
return self.tiled_encode(x)
raise NotImplementedError("Encoding without tiling has not been implemented yet.")
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`):
Input batch of videos.
return_dict (`bool`, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
# TODO(aryan): refactor tiling implementation
# if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
if self.use_tiling:
return self.tiled_decode(z)
raise NotImplementedError("Decoding without tiling has not been implemented yet.")
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
Args:
z (`torch.Tensor`):
Input batch of latent vectors.
return_dict (`bool`, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
local_batch_size = 1
rs = self.spatial_compression_ratio
rt = self.config.temporal_compression_ratio
batch_size, num_channels, num_frames, height, width = x.shape
output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1
count = 0
output_latent = x.new_zeros(
(
output_num_frames * output_height * output_width,
2 * self.config.latent_channels,
self.kernel[0] // rt,
self.kernel[1] // rs,
self.kernel[2] // rs,
)
)
vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]))
for i in range(output_num_frames):
for j in range(output_height):
for k in range(output_width):
n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
vae_batch_input[count % local_batch_size] = video_cube
if (
count % local_batch_size == local_batch_size - 1
or count == output_num_frames * output_height * output_width - 1
):
latent = self.encoder(vae_batch_input)
if (
count == output_num_frames * output_height * output_width - 1
and count % local_batch_size != local_batch_size - 1
):
output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
else:
output_latent[count - local_batch_size + 1 : count + 1] = latent
vae_batch_input = x.new_zeros(
(local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])
)
count += 1
latent = x.new_zeros(
(batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)
)
output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
output_overlap = (
output_kernel[0] - output_stride[0],
output_kernel[1] - output_stride[1],
output_kernel[2] - output_stride[2],
)
for i in range(output_num_frames):
n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
for j in range(output_height):
h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
for k in range(output_width):
w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
latent_mean = _prepare_for_blend(
(i, output_num_frames, output_overlap[0]),
(j, output_height, output_overlap[1]),
(k, output_width, output_overlap[2]),
output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
)
latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean
latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)
latent = self.quant_conv(latent)
latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
return latent
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
local_batch_size = 1
rs = self.spatial_compression_ratio
rt = self.config.temporal_compression_ratio
latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
batch_size, num_channels, num_frames, height, width = z.shape
## post quant conv (a mapping)
z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
z = self.post_quant_conv(z)
z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1
count = 0
decoded_videos = z.new_zeros(
(
output_num_frames * output_height * output_width,
self.config.out_channels,
self.kernel[0],
self.kernel[1],
self.kernel[2],
)
)
vae_batch_input = z.new_zeros(
(local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
)
for i in range(output_num_frames):
for j in range(output_height):
for k in range(output_width):
n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]
current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
vae_batch_input[count % local_batch_size] = current_latent
if (
count % local_batch_size == local_batch_size - 1
or count == output_num_frames * output_height * output_width - 1
):
current_video = self.decoder(vae_batch_input)
if (
count == output_num_frames * output_height * output_width - 1
and count % local_batch_size != local_batch_size - 1
):
decoded_videos[count - count % local_batch_size :] = current_video[
: count % local_batch_size + 1
]
else:
decoded_videos[count - local_batch_size + 1 : count + 1] = current_video
vae_batch_input = z.new_zeros(
(local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
)
count += 1
video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs))
video_overlap = (
self.kernel[0] - self.stride[0],
self.kernel[1] - self.stride[1],
self.kernel[2] - self.stride[2],
)
for i in range(output_num_frames):
n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
for j in range(output_height):
h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
for k in range(output_width):
w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
out_video_blend = _prepare_for_blend(
(i, output_num_frames, video_overlap[0]),
(j, output_height, video_overlap[1]),
(k, output_width, video_overlap[2]),
decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
)
video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
video = video.permute(0, 2, 1, 3, 4).contiguous()
return video
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
encoder_local_batch_size: int = 2,
decoder_local_batch_size: int = 2,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
PyTorch random number generator.
encoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the encoder's batch inference.
decoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the decoder's batch inference.
"""
x = sample
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def _prepare_for_blend(n_param, h_param, w_param, x):
# TODO(aryan): refactor
n, n_max, overlap_n = n_param
h, h_max, overlap_h = h_param
w, w_max, overlap_w = w_param
if overlap_n > 0:
if n > 0: # the head overlap part decays from 0 to 1
x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * (
torch.arange(0, overlap_n).float().to(x.device) / overlap_n
).reshape(overlap_n, 1, 1)
if n < n_max - 1: # the tail overlap part decays from 1 to 0
x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * (
1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n
).reshape(overlap_n, 1, 1)
if h > 0:
x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * (
torch.arange(0, overlap_h).float().to(x.device) / overlap_h
).reshape(overlap_h, 1)
if h < h_max - 1:
x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * (
1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h
).reshape(overlap_h, 1)
if w > 0:
x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * (
torch.arange(0, overlap_w).float().to(x.device) / overlap_w
)
if w < w_max - 1:
x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * (
1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w
)
return x
...@@ -564,6 +564,42 @@ def get_3d_rotary_pos_embed( ...@@ -564,6 +564,42 @@ def get_3d_rotary_pos_embed(
return cos, sin return cos, sin
def get_3d_rotary_pos_embed_allegro(
embed_dim,
crops_coords,
grid_size,
temporal_size,
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
theta: int = 10000,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(aryan): docs
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 3
dim_h = embed_dim // 3
dim_w = embed_dim // 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(
dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(
dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
)
freqs_w = get_1d_rotary_pos_embed(
dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
)
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
""" """
RoPE for image tokens with 2d structure. RoPE for image tokens with 2d structure.
...@@ -684,7 +720,7 @@ def get_1d_rotary_pos_embed( ...@@ -684,7 +720,7 @@ def get_1d_rotary_pos_embed(
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
elif use_real: elif use_real:
# stable audio # stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
...@@ -743,6 +779,24 @@ def apply_rotary_emb( ...@@ -743,6 +779,24 @@ def apply_rotary_emb(
return x_out.type_as(x) return x_out.type_as(x)
def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
# TODO(aryan): rewrite
def apply_1d_rope(tokens, pos, cos, sin):
cos = F.embedding(pos, cos)[:, None, :, :]
sin = F.embedding(pos, sin)[:, None, :, :]
x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
tokens_rotated = torch.cat((-x2, x1), dim=-1)
return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
(t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
t, h, w = x.chunk(3, dim=-1)
t = apply_1d_rope(t, positions[0], t_cos, t_sin)
h = apply_1d_rope(h, positions[1], h_cos, h_sin)
w = apply_1d_rope(w, positions[2], w_cos, w_sin)
x = torch.cat([t, h, w], dim=-1)
return x
class FluxPosEmbed(nn.Module): class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]): def __init__(self, theta: int, axes_dim: List[int]):
......
...@@ -22,10 +22,7 @@ import torch.nn.functional as F ...@@ -22,10 +22,7 @@ import torch.nn.functional as F
from ..utils import is_torch_version from ..utils import is_torch_version
from .activations import get_activation from .activations import get_activation
from .embeddings import ( from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
CombinedTimestepLabelEmbeddings,
PixArtAlphaCombinedTimestepSizeEmbeddings,
)
class AdaLayerNorm(nn.Module): class AdaLayerNorm(nn.Module):
...@@ -266,6 +263,7 @@ class AdaLayerNormSingle(nn.Module): ...@@ -266,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
hidden_dtype: Optional[torch.dtype] = None, hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here. # No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep return self.linear(self.silu(embedded_timestep)), embedded_timestep
......
...@@ -14,6 +14,7 @@ if is_torch_available(): ...@@ -14,6 +14,7 @@ if is_torch_available():
from .stable_audio_transformer import StableAudioDiTModel from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel from .transformer_sd3 import SD3Transformer2DModel
......
# Copyright 2024 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__)
@maybe_allow_in_graph
class AllegroTransformerBlock(nn.Module):
r"""
Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
cross_attention_dim (`int`, defaults to `2304`):
The dimension of the cross attention features.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
only_cross_attention (`bool`, defaults to `False`):
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
):
super().__init__()
# 1. Self Attention
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=AllegroAttnProcessor2_0(),
)
# 2. Cross Attention
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
processor=AllegroAttnProcessor2_0(),
)
# 3. Feed Forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
)
# 4. Scale-shift
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb=None,
) -> torch.Tensor:
# 0. Self-Attention
batch_size = hidden_states.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
image_rotary_emb=None,
)
hidden_states = attn_output + hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
# TODO(aryan): maybe following line is not required
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
"""
A 3D Transformer model for video-like data.
Args:
patch_size (`int`, defaults to `2`):
The size of spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `96`):
The number of channels in each head.
in_channels (`int`, defaults to `4`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `4`):
The number of channels in the output.
num_layers (`int`, defaults to `32`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
cross_attention_dim (`int`, defaults to `2304`):
The dimension of the cross attention features.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_height (`int`, defaults to `90`):
The height of the input latents.
sample_width (`int`, defaults to `160`):
The width of the input latents.
sample_frames (`int`, defaults to `22`):
The number of frames in the input latents.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
norm_elementwise_affine (`bool`, defaults to `False`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value to use in normalization layers.
caption_channels (`int`, defaults to `4096`):
Number of channels to use for projecting the caption embeddings.
interpolation_scale_h (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across height dimension.
interpolation_scale_w (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across width dimension.
interpolation_scale_t (`float`, defaults to `2.2`):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
@register_to_config
def __init__(
self,
patch_size: int = 2,
patch_size_t: int = 1,
num_attention_heads: int = 24,
attention_head_dim: int = 96,
in_channels: int = 4,
out_channels: int = 4,
num_layers: int = 32,
dropout: float = 0.0,
cross_attention_dim: int = 2304,
attention_bias: bool = True,
sample_height: int = 90,
sample_width: int = 160,
sample_frames: int = 22,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
interpolation_scale_h: float = 2.0,
interpolation_scale_w: float = 2.0,
interpolation_scale_t: float = 2.2,
):
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
interpolation_scale_t = (
interpolation_scale_t
if interpolation_scale_t is not None
else ((sample_frames - 1) // 16 + 1)
if sample_frames % 2 == 1
else sample_frames // 16
)
interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
# 1. Patch embedding
self.pos_embed = PatchEmbed(
height=sample_height,
width=sample_width,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
)
# 2. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AllegroTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
# 3. Output projection & norm
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
# 4. Timestep embeddings
self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
# 5. Caption projection
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t = self.config.patch_size_t
p = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
if attention_mask is not None and attention_mask.ndim == 4:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
# b, frame+use_image_num, h, w -> a video with images
# b, 1, h, w -> only images
attention_mask = attention_mask.to(hidden_states.dtype)
attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width]
if attention_mask.numel() > 0:
attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width]
attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
attention_mask = (
(1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Timestep embeddings
timestep, embedded_timestep = self.adaln_single(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Patch embeddings
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.pos_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
# TODO(aryan): Implement gradient checkpointing
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
timestep,
attention_mask,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=timestep,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output normalization & projection
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# 5. Unpatchify
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
...@@ -116,6 +116,7 @@ else: ...@@ -116,6 +116,7 @@ else:
"VersatileDiffusionTextToImagePipeline", "VersatileDiffusionTextToImagePipeline",
] ]
) )
_import_structure["allegro"] = ["AllegroPipeline"]
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [ _import_structure["animatediff"] = [
"AnimateDiffPipeline", "AnimateDiffPipeline",
...@@ -454,6 +455,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -454,6 +455,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * from ..utils.dummy_torch_and_transformers_objects import *
else: else:
from .allegro import AllegroPipeline
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import ( from .animatediff import (
AnimateDiffControlNetPipeline, AnimateDiffControlNetPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_allegro"] = ["AllegroPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_allegro import AllegroPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import math
import re
import urllib.parse as ul
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
from ...models.embeddings import get_3d_rotary_pos_embed_allegro
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import AllegroPipelineOutput
logger = logging.get_logger(__name__)
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AutoencoderKLAllegro, AllegroPipeline
>>> from diffusers.utils import export_to_video
>>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
>>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
>>> prompt = (
... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
... "location might be a popular spot for docking fishing boats."
... )
>>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
>>> export_to_video(video, "output.mp4", fps=15)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class AllegroPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using Allegro.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AllegroAutoEncoderKL3D`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`AllegroTransformer3DModel`]):
A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLAllegro,
transformer: AllegroTransformer3DModel,
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 512,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_videos_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt.
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because T5 can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
num_frames,
height,
width,
callback_on_step_end_tensor_inputs,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if num_frames <= 0:
raise ValueError(f"`num_frames` have to be positive but is {num_frames}.")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if num_frames % 2 == 0:
num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal)
else:
num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width]
return frames
def _prepare_rotary_positional_embeddings(
self,
batch_size: int,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
start, stop = (0, 0), (grid_height, grid_width)
freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=(start, stop),
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
interpolation_scale=(
self.transformer.config.interpolation_scale_t,
self.transformer.config.interpolation_scale_h,
self.transformer.config.interpolation_scale_w,
),
)
grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
grid_t, grid_h, grid_w = pos
freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 100,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
clean_caption: bool = True,
max_sequence_length: int = 512,
) -> Union[AllegroPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
usually at the expense of lower video quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
num_frames: (`int`, *optional*, defaults to 88):
The number controls the generated video frames.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated video.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
max_sequence_length (`int` defaults to `512`):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated videos.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
self.check_inputs(
prompt,
num_frames,
height,
width,
callback_on_step_end_tensor_inputs,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare rotary embeddings
image_rotary_emb = self._prepare_rotary_positional_embeddings(
batch_size, height, width, latents.size(2), device
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
video = video[:, :, :num_frames, :height, :width]
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AllegroPipelineOutput(frames=video)
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL
import torch
from diffusers.utils import BaseOutput
@dataclass
class AllegroPipelineOutput(BaseOutput):
r"""
Output class for Allegro pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
...@@ -2,6 +2,21 @@ ...@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class AllegroTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AsymmetricAutoencoderKL(metaclass=DummyObject): class AsymmetricAutoencoderKL(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -47,6 +62,21 @@ class AutoencoderKL(metaclass=DummyObject): ...@@ -47,6 +62,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class AutoencoderKLAllegro(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLCogVideoX(metaclass=DummyObject): class AutoencoderKLCogVideoX(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -2,6 +2,21 @@ ...@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AllegroTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = AllegroTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 2, 8, 8)
@property
def output_shape(self):
return (4, 2, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"num_layers": 1,
"cross_attention_dim": 16,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"caption_channels": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment