Unverified Commit 0d1d267b authored by Aryan, committed by GitHub

[core] Allegro T2V (#9736)



* update

* refactor transformer part 1

* refactor part 2

* refactor part 3

* make style

* refactor part 4; modeling tests

* make style

* refactor part 5

* refactor part 6

* gradient checkpointing

* pipeline tests (broken atm)

* update

* add coauthor
Co-Authored-By: Huan Yang <hyang@fastmail.com>

* refactor part 7

* add docs

* make style

* add coauthor
Co-Authored-By: YiYi Xu <yixu310@gmail.com>

* make fix-copies

* undo unrelated change

* revert changes to embeddings, normalization, transformer

* refactor part 8

* make style

* refactor part 9

* make style

* fix

* apply suggestions from review

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update example

* remove attention mask for self-attention

* update

* copied from

* update

* update

---------
Co-authored-by: Huan Yang <hyang@fastmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent c5376c56
......@@ -252,6 +252,8 @@
title: SparseControlNetModel
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
......@@ -300,6 +302,8 @@
- sections:
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/asymmetricautoencoderkl
......@@ -318,6 +322,8 @@
sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AllegroTransformer3DModel
A Diffusion Transformer model for 3D video-like data, used in [Allegro](https://github.com/rhymes-ai/Allegro), was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import AllegroTransformer3DModel

transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## AllegroTransformer3DModel
[[autodoc]] AllegroTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLAllegro
The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLAllegro
[[autodoc]] AutoencoderKLAllegro
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Allegro
[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) is from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, and Huan Yang.
The abstract from the paper is:
*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
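The example below is a minimal text-to-video sketch. The prompt, `guidance_scale`, and `num_inference_steps` values are illustrative rather than recommended settings, and `export_to_video` is only used to write the generated frames to disk.
```python
import torch
from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

# Load the full pipeline in bfloat16 from the "rhymes-ai/Allegro" repository referenced above.
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A seaside harbor at sunset, small boats gently rocking on sparkling water."

# Inference settings here are illustrative; tune them for your quality/speed trade-off.
video = pipe(prompt=prompt, guidance_scale=7.5, num_inference_steps=100).frames[0]
export_to_video(video, "allegro_output.mp4", fps=15)
```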
## AllegroPipeline
[[autodoc]] AllegroPipeline
- all
- __call__
## AllegroPipelineOutput
[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
......@@ -77,9 +77,11 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
......@@ -237,6 +239,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
......@@ -556,9 +559,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
......@@ -697,6 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AmusedImg2ImgPipeline,
......
......@@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
......@@ -54,6 +55,7 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
......@@ -81,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
......@@ -97,6 +100,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
......
......@@ -1521,6 +1521,100 @@ class FusedJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
class AllegroAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Allegro model. It applies a rotary embedding on the query and key vectors for self-attention.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
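# Reshape the projections to (batch, heads, seq_len, head_dim) as expected by scaled_dot_product_attention.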
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Apply RoPE if needed
if image_rotary_emb is not None and not attn.is_cross_attention:
from .embeddings import apply_rotary_emb_allegro
query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow."""
......
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
......
This diff is collapsed.
......@@ -564,6 +564,42 @@ def get_3d_rotary_pos_embed(
return cos, sin
def get_3d_rotary_pos_embed_allegro(
embed_dim,
crops_coords,
grid_size,
temporal_size,
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
theta: int = 10000,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(aryan): docs
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 3
dim_h = embed_dim // 3
dim_w = embed_dim // 3
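# Each axis gets an equal share of the rotary channels (embed_dim is assumed to be divisible by 3).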
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(
dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(
dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
)
freqs_w = get_1d_rotary_pos_embed(
dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
)
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
......@@ -684,7 +720,7 @@ def get_1d_rotary_pos_embed(
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
......@@ -743,6 +779,24 @@ def apply_rotary_emb(
return x_out.type_as(x)
def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
# TODO(aryan): rewrite
def apply_1d_rope(tokens, pos, cos, sin):
cos = F.embedding(pos, cos)[:, None, :, :]
sin = F.embedding(pos, sin)[:, None, :, :]
x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
tokens_rotated = torch.cat((-x2, x1), dim=-1)
return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
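# Split the channel dimension into temporal, height, and width thirds and rotate each third with its own 1D RoPE frequencies, gathered at the per-token grid positions.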
(t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
t, h, w = x.chunk(3, dim=-1)
t = apply_1d_rope(t, positions[0], t_cos, t_sin)
h = apply_1d_rope(h, positions[1], h_cos, h_sin)
w = apply_1d_rope(w, positions[2], w_cos, w_sin)
x = torch.cat([t, h, w], dim=-1)
return x
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
......
......@@ -22,10 +22,7 @@ import torch.nn.functional as F
from ..utils import is_torch_version
from .activations import get_activation
from .embeddings import (
CombinedTimestepLabelEmbeddings,
PixArtAlphaCombinedTimestepSizeEmbeddings,
)
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
class AdaLayerNorm(nn.Module):
......@@ -266,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
......
......@@ -14,6 +14,7 @@ if is_torch_available():
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
......
# Copyright 2024 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__)
@maybe_allow_in_graph
class AllegroTransformerBlock(nn.Module):
r"""
Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The dimension of the cross attention features.
activation_fn (`str`, defaults to `"geglu"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
):
super().__init__()
# 1. Self Attention
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=AllegroAttnProcessor2_0(),
)
# 2. Cross Attention
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
processor=AllegroAttnProcessor2_0(),
)
# 3. Feed Forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
)
# 4. Scale-shift
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb=None,
) -> torch.Tensor:
# 0. Self-Attention
batch_size = hidden_states.shape[0]
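# AdaLN-single modulation: combine the learned scale-shift table with the timestep embedding to produce shift/scale/gate terms for the attention and feed-forward branches.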
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
image_rotary_emb=None,
)
hidden_states = attn_output + hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
# TODO(aryan): maybe the following line is not required
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
"""
A 3D Transformer model for video-like data.
Args:
patch_size (`int`, defaults to `2`):
The size of spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `96`):
The number of channels in each head.
in_channels (`int`, defaults to `4`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `4`):
The number of channels in the output.
num_layers (`int`, defaults to `32`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
cross_attention_dim (`int`, defaults to `2304`):
The dimension of the cross attention features.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_height (`int`, defaults to `90`):
The height of the input latents.
sample_width (`int`, defaults to `160`):
The width of the input latents.
sample_frames (`int`, defaults to `22`):
The number of frames in the input latents.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
norm_elementwise_affine (`bool`, defaults to `False`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value to use in normalization layers.
caption_channels (`int`, defaults to `4096`):
Number of channels to use for projecting the caption embeddings.
interpolation_scale_h (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across height dimension.
interpolation_scale_w (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across width dimension.
interpolation_scale_t (`float`, defaults to `2.2`):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
@register_to_config
def __init__(
self,
patch_size: int = 2,
patch_size_t: int = 1,
num_attention_heads: int = 24,
attention_head_dim: int = 96,
in_channels: int = 4,
out_channels: int = 4,
num_layers: int = 32,
dropout: float = 0.0,
cross_attention_dim: int = 2304,
attention_bias: bool = True,
sample_height: int = 90,
sample_width: int = 160,
sample_frames: int = 22,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
interpolation_scale_h: float = 2.0,
interpolation_scale_w: float = 2.0,
interpolation_scale_t: float = 2.2,
):
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
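# If the interpolation scales are not explicitly provided, derive them from the default latent sample size.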
interpolation_scale_t = (
interpolation_scale_t
if interpolation_scale_t is not None
else ((sample_frames - 1) // 16 + 1)
if sample_frames % 2 == 1
else sample_frames // 16
)
interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
# 1. Patch embedding
self.pos_embed = PatchEmbed(
height=sample_height,
width=sample_width,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
)
# 2. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AllegroTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
# 3. Output projection & norm
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
# 4. Timestep embeddings
self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
# 5. Caption projection
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t = self.config.patch_size_t
p = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
attention_mask_vid, attention_mask_img = None, None
if attention_mask is not None and attention_mask.ndim == 4:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
# b, frame+use_image_num, h, w -> a video with images
# b, 1, h, w -> only images
attention_mask = attention_mask.to(hidden_states.dtype)
attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width]
if attention_mask.numel() > 0:
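# Pool the mask with the spatio-temporal patch sizes so it matches the patchified token sequence, then flatten it to (batch, 1, num_tokens).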
attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width]
attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
attention_mask = (
(1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Timestep embeddings
timestep, embedded_timestep = self.adaln_single(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Patch embeddings
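# Fold frames into the batch dimension so the 2D patch embedding is applied per frame, then flatten frame and spatial patches into a single token sequence of shape (batch, frames * num_patches, inner_dim).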
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.pos_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
timestep,
attention_mask,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=timestep,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output normalization & projection
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# 5. Unpatchify
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
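# (batch, frames', height', width', p_t, p, p, channels) -> (batch, channels, frames, height, width) after the final reshape.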
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
......@@ -116,6 +116,7 @@ else:
"VersatileDiffusionTextToImagePipeline",
]
)
_import_structure["allegro"] = ["AllegroPipeline"]
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
......@@ -454,6 +455,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .allegro import AllegroPipeline
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import (
AnimateDiffControlNetPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_allegro"] = ["AllegroPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_allegro import AllegroPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL
import torch
from diffusers.utils import BaseOutput
@dataclass
class AllegroPipelineOutput(BaseOutput):
r"""
Output class for Allegro pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
......@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class AllegroTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AsymmetricAutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
......@@ -47,6 +62,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLAllegro(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLCogVideoX(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AllegroTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = AllegroTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 2, 8, 8)
@property
def output_shape(self):
return (4, 2, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"num_layers": 1,
"cross_attention_dim": 16,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"caption_channels": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict