Unverified Commit 5e3b7d2d authored by Bubbliiiing, committed by GitHub

Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model (#10626)



* Update EasyAnimate V5.1

* Add docs && add tests && Fix comments problems in transformer3d and vae

* delete comments and remove useless import

* delete process

* Update EXAMPLE_DOC_STRING

* rename transformer file

* make fix-copies

* make style

* refactor pt. 1

* update toctree.yml

* add model tests

* Update layer_norm for norm_added_q and norm_added_k in Attention

* Fix processor problem

* refactor vae

* Fix problem in comments

* refactor tiling; remove einops dependency

* fix docs path

* make fix-copies

* Update src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py

* update _toctree.yml

* fix test

* update

* update

* update

* make fix-copies

* fix tests

---------
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 7513162b
@@ -290,6 +290,8 @@
title: CogView4Transformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -352,6 +354,8 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/autoencoder_kl_wan
@@ -430,6 +434,8 @@
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/easyanimate
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
......
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLMagvit
The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```
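Sliced and tiled decoding, implemented in this PR, can lower peak memory when decoding large or long videos. A minimal sketch: the random `latents` tensor is only a stand-in for latents produced by an EasyAnimate pipeline, with shapes following the default config (16 latent channels, 8x spatial and 4x temporal compression).

```python
import torch

from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # decode one video latent at a time across the batch dimension
vae.enable_tiling()   # decode spatial tiles and blend their overlaps to avoid seams

# Stand-in latents: [batch, latent_channels, latent_frames, latent_height, latent_width]
latents = torch.randn(1, 16, 13, 64, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    video = vae.decode(latents).sample  # decodes to 49 frames at 512x512
```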
## AutoencoderKLMagvit
[[autodoc]] AutoencoderKLMagvit
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# EasyAnimateTransformer3DModel
A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
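On its own, the transformer is mainly loaded this way to customize or quantize it before handing it to a pipeline. A minimal sketch of that pattern, mirroring the quantization example in the EasyAnimate pipeline docs:

```python
import torch

from diffusers import EasyAnimatePipeline, EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16
)
pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", transformer=transformer, torch_dtype=torch.float16
).to("cuda")
```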
## EasyAnimateTransformer3DModel
[[autodoc]] EasyAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# EasyAnimate
[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.
The description from its GitHub page:
*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*
This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://github.com/aigc-apps/EasyAnimate). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
There are two official EasyAnimate checkpoints for text-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
There is one official EasyAnimate checkpoint available for image-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
There are two official EasyAnimate checkpoints available for control-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |
For the EasyAnimateV5.1 series:
- Text-to-video (T2V) and image-to-video (I2V) generation work at multiple resolutions; the width and height can vary from 256 to 1024.
- Both the T2V and I2V models support generating 1 to 49 frames and work best within this range. Exporting videos at 8 FPS is recommended, as shown in the sketch below.
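A minimal text-to-video sketch, using the same call pattern as the quantization example below but without quantization:

```python
import torch

from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16
).to("cuda")

prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=49, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=8)
```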
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
"alibaba-pai/EasyAnimateV5.1-12b-zh",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
pipeline = EasyAnimatePipeline.from_pretrained(
"alibaba-pai/EasyAnimateV5.1-12b-zh",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=8)
```
## EasyAnimatePipeline
[[autodoc]] EasyAnimatePipeline
- all
- __call__
## EasyAnimatePipelineOutput
[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
@@ -94,6 +94,7 @@ else:
"AutoencoderKLCogVideoX",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
@@ -109,6 +110,7 @@ else:
"ControlNetUnionModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
@@ -293,6 +295,9 @@ else:
"CogView4Pipeline",
"ConsisIDPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -620,6 +625,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
@@ -635,6 +641,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetUnionModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
@@ -798,6 +805,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Pipeline,
ConsisIDPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
......
@@ -33,6 +33,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
@@ -72,6 +73,7 @@ if is_torch_available():
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -109,6 +111,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
@@ -144,6 +147,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsisIDTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoTransformer3DModel,
......
@@ -274,7 +274,10 @@ class Attention(nn.Module):
self.to_add_out = None
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
if qk_norm == "layer_norm":
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
......
@@ -5,6 +5,7 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
......
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class EasyAnimateCausalConv3d(nn.Conv3d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 1,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
):
# Ensure kernel_size, stride, and dilation are tuples of length 3
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
stride = stride if isinstance(stride, tuple) else (stride,) * 3
assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
# Unpack kernel size, stride, and dilation for temporal, height, and width dimensions
t_ks, h_ks, w_ks = kernel_size
self.t_stride, h_stride, w_stride = stride
t_dilation, h_dilation, w_dilation = dilation
# Calculate padding for temporal dimension to maintain causality
t_pad = (t_ks - 1) * t_dilation
# Calculate padding for height and width dimensions based on the padding parameter
if padding is None:
h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
elif isinstance(padding, int):
h_pad = w_pad = padding
else:
            raise NotImplementedError(f"Unsupported padding: {padding!r}")
# Store temporal padding and initialize flags and previous features cache
self.temporal_padding = t_pad
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - self.t_stride)) / 2)
self.prev_features = None
# Initialize the parent class with modified padding
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=(0, h_pad, w_pad),
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
def _clear_conv_cache(self):
del self.prev_features
self.prev_features = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Ensure input tensor is of the correct type
dtype = hidden_states.dtype
if self.prev_features is None:
# Pad the input tensor in the temporal dimension to maintain causality
hidden_states = F.pad(
hidden_states,
pad=(0, 0, 0, 0, self.temporal_padding, 0),
mode="replicate", # TODO: check if this is necessary
)
hidden_states = hidden_states.to(dtype=dtype)
# Clear cache before processing and store previous features for causality
self._clear_conv_cache()
self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
# Process the input tensor in chunks along the temporal dimension
num_frames = hidden_states.size(2)
outputs = []
i = 0
while i + self.temporal_padding + 1 <= num_frames:
out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
i += self.t_stride
outputs.append(out)
return torch.concat(outputs, 2)
else:
# Concatenate previous features with the input tensor for continuous temporal processing
if self.t_stride == 2:
hidden_states = torch.concat(
[self.prev_features[:, :, -(self.temporal_padding - 1) :], hidden_states], dim=2
)
else:
hidden_states = torch.concat([self.prev_features, hidden_states], dim=2)
hidden_states = hidden_states.to(dtype=dtype)
# Clear cache and update previous features
self._clear_conv_cache()
self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
# Process the concatenated tensor in chunks along the temporal dimension
num_frames = hidden_states.size(2)
outputs = []
i = 0
while i + self.temporal_padding + 1 <= num_frames:
out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
i += self.t_stride
outputs.append(out)
return torch.concat(outputs, 2)
class EasyAnimateResidualBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
non_linearity: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
spatial_group_norm: bool = True,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
):
super().__init__()
self.output_scale_factor = output_scale_factor
# Group normalization for input tensor
self.norm1 = nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=in_channels,
eps=norm_eps,
affine=True,
)
self.nonlinearity = get_activation(non_linearity)
self.conv1 = EasyAnimateCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = EasyAnimateCausalConv3d(out_channels, out_channels, kernel_size=3)
if in_channels != out_channels:
self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.shortcut = nn.Identity()
self.spatial_group_norm = spatial_group_norm
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(hidden_states)
if self.spatial_group_norm:
batch_size = hidden_states.size(0)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
hidden_states = self.norm1(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
) # [B * T, C, H, W] -> [B, C, T, H, W]
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.spatial_group_norm:
batch_size = hidden_states.size(0)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
hidden_states = self.norm2(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
) # [B * T, C, H, W] -> [B, C, T, H, W]
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
return (hidden_states + shortcut) / self.output_scale_factor
class EasyAnimateDownsampler3D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: tuple = (2, 2, 2)):
super().__init__()
self.conv = EasyAnimateCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1))
hidden_states = self.conv(hidden_states)
return hidden_states
class EasyAnimateUpsampler3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
temporal_upsample: bool = False,
spatial_group_norm: bool = True,
):
super().__init__()
out_channels = out_channels or in_channels
self.temporal_upsample = temporal_upsample
self.spatial_group_norm = spatial_group_norm
self.conv = EasyAnimateCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
)
self.prev_features = None
def _clear_conv_cache(self):
del self.prev_features
self.prev_features = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.interpolate(hidden_states, scale_factor=(1, 2, 2), mode="nearest")
hidden_states = self.conv(hidden_states)
if self.temporal_upsample:
if self.prev_features is None:
self.prev_features = hidden_states
else:
hidden_states = F.interpolate(
hidden_states,
scale_factor=(2, 1, 1),
mode="trilinear" if not self.spatial_group_norm else "nearest",
)
return hidden_states
class EasyAnimateDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
spatial_group_norm: bool = True,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
add_temporal_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
EasyAnimateResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
spatial_group_norm=spatial_group_norm,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_downsample and add_temporal_downsample:
self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(2, 2, 2))
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 2
elif add_downsample and not add_temporal_downsample:
self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(1, 2, 2))
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 1
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for conv in self.convs:
hidden_states = conv(hidden_states)
if self.downsampler is not None:
hidden_states = self.downsampler(hidden_states)
return hidden_states
class EasyAnimateUpBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
spatial_group_norm: bool = False,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
add_temporal_upsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
EasyAnimateResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
spatial_group_norm=spatial_group_norm,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_upsample:
self.upsampler = EasyAnimateUpsampler3D(
in_channels,
in_channels,
temporal_upsample=add_temporal_upsample,
spatial_group_norm=spatial_group_norm,
)
else:
self.upsampler = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for conv in self.convs:
hidden_states = conv(hidden_states)
if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states)
return hidden_states
class EasyAnimateMidBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
spatial_group_norm: bool = True,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
):
super().__init__()
norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)
self.convs = nn.ModuleList(
[
EasyAnimateResidualBlock3D(
in_channels=in_channels,
out_channels=in_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
spatial_group_norm=spatial_group_norm,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
]
)
for _ in range(num_layers - 1):
self.convs.append(
EasyAnimateResidualBlock3D(
in_channels=in_channels,
out_channels=in_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
spatial_group_norm=spatial_group_norm,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.convs[0](hidden_states)
for resnet in self.convs[1:]:
hidden_states = resnet(hidden_states)
return hidden_states
class EasyAnimateEncoder(nn.Module):
r"""
Causal encoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 3,
out_channels: int = 8,
down_block_types: Tuple[str, ...] = (
"SpatialDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
),
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
spatial_group_norm: bool = False,
):
super().__init__()
# 1. Input convolution
self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
# 2. Down blocks
self.down_blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channels = output_channels
output_channels = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if down_block_type == "SpatialDownBlock3D":
down_block = EasyAnimateDownBlock3D(
in_channels=input_channels,
out_channels=output_channels,
num_layers=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
spatial_group_norm=spatial_group_norm,
add_downsample=not is_final_block,
add_temporal_downsample=False,
)
elif down_block_type == "SpatialTemporalDownBlock3D":
down_block = EasyAnimateDownBlock3D(
in_channels=input_channels,
out_channels=output_channels,
num_layers=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
spatial_group_norm=spatial_group_norm,
add_downsample=not is_final_block,
add_temporal_downsample=True,
)
else:
raise ValueError(f"Unknown up block type: {down_block_type}")
self.down_blocks.append(down_block)
# 3. Middle block
self.mid_block = EasyAnimateMidBlock3d(
in_channels=block_out_channels[-1],
num_layers=layers_per_block,
act_fn=act_fn,
spatial_group_norm=spatial_group_norm,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
dropout=0,
output_scale_factor=1,
)
# 4. Output normalization & convolution
self.spatial_group_norm = spatial_group_norm
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1],
num_groups=norm_num_groups,
eps=1e-6,
)
self.conv_act = get_activation(act_fn)
# Initialize the output convolution layer
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = EasyAnimateCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states: (B, C, T, H, W)
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
else:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
if self.spatial_group_norm:
batch_size = hidden_states.size(0)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class EasyAnimateDecoder(nn.Module):
r"""
Causal decoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int = 8,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = (
"SpatialUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
),
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
spatial_group_norm: bool = False,
):
super().__init__()
# 1. Input convolution
self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)
# 2. Middle block
self.mid_block = EasyAnimateMidBlock3d(
in_channels=block_out_channels[-1],
num_layers=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
dropout=0,
output_scale_factor=1,
)
# 3. Up blocks
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channels = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
input_channels = output_channels
output_channels = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# Create and append up block to up_blocks
if up_block_type == "SpatialUpBlock3D":
up_block = EasyAnimateUpBlock3d(
in_channels=input_channels,
out_channels=output_channels,
num_layers=layers_per_block + 1,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
spatial_group_norm=spatial_group_norm,
add_upsample=not is_final_block,
add_temporal_upsample=False,
)
elif up_block_type == "SpatialTemporalUpBlock3D":
up_block = EasyAnimateUpBlock3d(
in_channels=input_channels,
out_channels=output_channels,
num_layers=layers_per_block + 1,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=1e-6,
spatial_group_norm=spatial_group_norm,
add_upsample=not is_final_block,
add_temporal_upsample=True,
)
else:
raise ValueError(f"Unknown up block type: {up_block_type}")
self.up_blocks.append(up_block)
# Output normalization and activation
self.spatial_group_norm = spatial_group_norm
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0],
num_groups=norm_num_groups,
eps=1e-6,
)
self.conv_act = get_activation(act_fn)
# Output convolution layer
self.conv_out = EasyAnimateCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states: (B, C, T, H, W)
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = up_block(hidden_states)
if self.spatial_group_norm:
batch_size = hidden_states.size(0)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
) # [B * T, C, H, W] -> [B, C, T, H, W]
else:
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. This
    model is used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 16,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
down_block_types: Tuple[str, ...] = [
"SpatialDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
],
up_block_types: Tuple[str, ...] = [
"SpatialUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
],
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
scaling_factor: float = 0.7125,
spatial_group_norm: bool = True,
):
super().__init__()
# Initialize the encoder
self.encoder = EasyAnimateEncoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
spatial_group_norm=spatial_group_norm,
)
# Initialize the decoder
self.decoder = EasyAnimateDecoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
spatial_group_norm=spatial_group_norm,
)
# Initialize convolution layers for quantization and post-quantization
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
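        # Compression ratios implied by the default block layout: every block except the last downsamples
        # spatially by 2x (8x total), while only the two middle "SpatialTemporal" blocks downsample temporally (4x total).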
self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
self.use_framewise_encoding = False
self.use_framewise_decoding = False
# Assign mini-batch sizes for encoder and decoder
self.num_sample_frames_batch_size = 4
self.num_latent_frames_batch_size = 1
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 4
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
def _clear_conv_cache(self):
# Clear cache for convolutional layers if needed
for name, module in self.named_modules():
if isinstance(module, EasyAnimateCausalConv3d):
module._clear_conv_cache()
if isinstance(module, EasyAnimateUpsampler3D):
module._clear_conv_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
        processing larger images and longer videos.
        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_min_num_frames (`int`, *optional*):
                The minimum number of frames required for a sample to be separated into tiles across the frame
                dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. The overlap between tiles ensures that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. The overlap between tiles ensures that there are
                no tiling artifacts produced across the width dimension.
            tile_sample_stride_num_frames (`int`, *optional*):
                The stride between two consecutive frame tiles. The overlap between tiles ensures that there are no
                tiling artifacts produced across the frame dimension.
"""
self.use_tiling = True
self.use_framewise_decoding = True
self.use_framewise_encoding = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
            The latent moments (concatenated mean and log-variance) of the encoded videos, which `encode` wraps
            into a [`~models.autoencoder_kl.AutoencoderKLOutput`].
"""
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_width or x.shape[-2] > self.tile_sample_min_height):
return self.tiled_encode(x, return_dict=return_dict)
first_frames = self.encoder(x[:, :, :1, :, :])
h = [first_frames]
for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
h.append(next_frames)
h = torch.cat(h, dim=2)
moments = self.quant_conv(h)
self._clear_conv_cache()
return moments
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        if self.use_tiling and (z.shape[-1] > tile_latent_min_width or z.shape[-2] > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
# Process the first frame and save the result
first_frames = self.decoder(z[:, :, :1, :, :])
# Initialize the list to store the processed frames, starting with the first frame
dec = [first_frames]
        # Process the remaining frames in batches of `num_latent_frames_batch_size` frames at a time
        for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
            next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
            dec.append(next_frames)
        # Concatenate all processed frames along the frame (temporal) dimension
        dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
self._clear_conv_cache()
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
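    # blend_v / blend_h linearly cross-fade the overlapping rows / columns of neighboring tiles so that tiled
    # encoding and decoding do not produce visible seams.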
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
first_frames = self.encoder(tile[:, :, 0:1, :, :])
tile_h = [first_frames]
for k in range(1, num_frames, self.num_sample_frames_batch_size):
next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
tile_h.append(next_frames)
tile = torch.cat(tile_h, dim=2)
tile = self.quant_conv(tile)
self._clear_conv_cache()
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
tile = self.post_quant_conv(tile)
# Process the first frame and save the result
first_frames = self.decoder(tile[:, :, :1, :, :])
# Initialize the list to store the processed frames, starting with the first frame
tile_dec = [first_frames]
                # Process the remaining frames in batches of `num_latent_frames_batch_size` frames at a time
                for k in range(1, num_frames, self.num_latent_frames_batch_size):
                    next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
                    tile_dec.append(next_frames)
                # Concatenate all processed frames along the frame (temporal) dimension
                decoded = torch.cat(tile_dec, dim=2)
self._clear_conv_cache()
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@@ -19,6 +19,7 @@ if is_torch_available():
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
......
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class EasyAnimateLayerNormZero(nn.Module):
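    # Norm-zero (AdaLN-Zero-style) modulation: a single linear projection of the timestep embedding yields
    # shift / scale / gate triplets for both the video stream and the text stream.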
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "fp32_layer_norm",
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
1
)
return hidden_states, encoder_hidden_states, gate, enc_gate
class EasyAnimateRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
super().__init__()
self.patch_size = patch_size
self.rope_dim = rope_dim
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
bs, c, num_frames, grid_height, grid_width = hidden_states.size()
grid_height = grid_height // self.patch_size
grid_width = grid_width // self.patch_size
base_size_width = 90 // self.patch_size
base_size_height = 60 // self.patch_size
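        # The base grid corresponds to the default latent resolution the model was configured with
        # (`sample_width=90`, `sample_height=60` in `EasyAnimateTransformer3DModel`).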
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
image_rotary_emb = get_3d_rotary_pos_embed(
self.rope_dim,
grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=hidden_states.size(2),
use_real=True,
)
return image_rotary_emb
class EasyAnimateAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the EasyAnimateTransformer3DModel model.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=2)
key = torch.cat([encoder_key, key], dim=2)
value = torch.cat([encoder_value, value], dim=2)
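        # 4. Rotary position embeddings, applied to the vision tokens only (the text tokens occupy the first
        # `encoder_hidden_states.shape[1]` positions of the joint sequence)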
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
else:
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class EasyAnimateTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-6,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
qk_norm: bool = True,
after_norm: bool = False,
norm_type: str = "fp32_layer_norm",
is_mmdit_block: bool = True,
):
super().__init__()
# Attention Part
self.norm1 = EasyAnimateLayerNormZero(
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
added_proj_bias=True,
added_kv_proj_dim=dim if is_mmdit_block else None,
context_pre_only=False if is_mmdit_block else None,
processor=EasyAnimateAttnProcessor2_0(),
)
# FFN Part
self.norm2 = EasyAnimateLayerNormZero(
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.txt_ff = None
if is_mmdit_block:
self.txt_ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.norm3 = None
if after_norm:
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Attention
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
# 2. Feed-forward
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
if self.norm3 is not None:
norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
if self.txt_ff is not None:
norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
else:
norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
else:
norm_hidden_states = self.ff(norm_hidden_states)
if self.txt_ff is not None:
norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
else:
norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
return hidden_states, encoder_hidden_states
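# A minimal sketch of exercising one block in isolation (hypothetical sizes matching the
# defaults num_attention_heads=48 and attention_head_dim=64, so dim = 48 * 64 = 3072):
#
#     block = EasyAnimateTransformerBlock(dim=3072, num_attention_heads=48, attention_head_dim=64, time_embed_dim=512)
#     hidden_states = torch.randn(1, 17550, 3072)        # video tokens
#     encoder_hidden_states = torch.randn(1, 256, 3072)  # text tokens
#     temb = torch.randn(1, 512)
#     hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb)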
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
Parameters:
num_attention_heads (`int`, defaults to `48`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input.
out_channels (`int`, *optional*):
The number of channels in the output.
patch_size (`int`, *optional*):
The size of the patches to use in the patch embedding layer.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
num_layers (`int`, defaults to `48`):
The number of layers of Transformer blocks to use.
mmdit_layers (`int`, defaults to `48`):
The number of layers of Multi Modal Transformer (MMDiT) blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `3584`):
Input dimension of text embeddings from the text encoder.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use elementwise affine in normalization layers.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_position_encoding_type (`str`, defaults to `3d_rope`):
Type of time position encoding.
after_norm (`bool`, defaults to `False`):
Whether to apply a final normalization layer after the feed-forward layer.
resize_inpaint_mask_directly (`bool`, defaults to `True`):
Whether to resize the inpaint mask directly to the latent size.
enable_text_attention_mask (`bool`, defaults to `True`):
Whether to apply an attention mask over the text embeddings.
add_noise_in_inpaint_model (`bool`, defaults to `True`):
Whether to add noise in the inpaint variant of the model.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["EasyAnimateTransformerBlock"]
_skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
@register_to_config
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 64,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
patch_size: Optional[int] = None,
sample_width: int = 90,
sample_height: int = 60,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
freq_shift: int = 0,
num_layers: int = 48,
mmdit_layers: int = 48,
dropout: float = 0.0,
time_embed_dim: int = 512,
add_norm_text_encoder: bool = False,
text_embed_dim: int = 3584,
text_embed_dim_t5: Optional[int] = None,
norm_eps: float = 1e-5,
norm_elementwise_affine: bool = True,
flip_sin_to_cos: bool = True,
time_position_encoding_type: str = "3d_rope",
after_norm: bool = False,
resize_inpaint_mask_directly: bool = True,
enable_text_attention_mask: bool = True,
add_noise_in_inpaint_model: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
# 1. Timestep embedding
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
# 2. Patch embedding
self.proj = nn.Conv2d(
in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
)
# 3. Text refined embedding
self.text_proj = None
self.text_proj_t5 = None
if not add_norm_text_encoder:
self.text_proj = nn.Linear(text_embed_dim, inner_dim)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
else:
self.text_proj = nn.Sequential(
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Sequential(
RMSNorm(text_embed_dim_t5, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
)
# 4. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
EasyAnimateTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
after_norm=after_norm,
is_mmdit_block=layer_idx < mmdit_layers,
)
for layer_idx in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output norm & projection
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
timestep_cond: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, channels, video_length, height, width = hidden_states.size()
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
# 1. Time embedding
temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
temb = self.time_embedding(temb, timestep_cond)
image_rotary_emb = self.rope_embedding(hidden_states)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 1)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, F, H, W] -> [BF, C, H, W]
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
)  # [BF, C, H, W] -> [B, C, F, H, W]
hidden_states = hidden_states.flatten(2, 4).transpose(1, 2)  # [B, C, F, H, W] -> [B, FHW, C]
# 3. Text embedding
encoder_hidden_states = self.text_proj(encoder_hidden_states)
if encoder_hidden_states_t5 is not None:
encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
# 4. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
hidden_states = self.norm_final(hidden_states)
# 5. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, -1, p, p)
output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
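# A minimal sketch of a full forward pass (hypothetical shapes, assuming in_channels=16 and
# patch_size=2): a [B, C, F, H, W] = [1, 16, 13, 60, 90] latent is patchified into
# 13 * 30 * 45 = 17550 tokens, fused with the text tokens for attention, and unpatchified
# back to [1, out_channels, 13, 60, 90]:
#
#     sample = transformer(
#         hidden_states=torch.randn(1, 16, 13, 60, 90),
#         timestep=torch.tensor([999]),
#         encoder_hidden_states=torch.randn(1, 256, 3584),  # Qwen2-VL hidden size
#     ).sample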
......@@ -216,6 +216,11 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
]
_import_structure["easyanimate"] = [
"EasyAnimatePipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["kandinsky"] = [
......@@ -546,6 +551,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .easyanimate import (
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
)
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"]
_import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"]
_import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_easyanimate import EasyAnimatePipeline
from .pipeline_easyanimate_control import EasyAnimateControlPipeline
from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import (
BertModel,
BertTokenizer,
Qwen2Tokenizer,
Qwen2VLForConditionalGeneration,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import EasyAnimatePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import EasyAnimatePipeline
>>> from diffusers.utils import export_to_video
>>> # Models: "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers" or "alibaba-pai/EasyAnimateV5.1-12b-zh-diffusers"
>>> pipe = EasyAnimatePipeline.from_pretrained(
... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> sample_size = (512, 512)
>>> video = pipe(
... prompt=prompt,
... guidance_scale=6,
... negative_prompt="bad detailed",
... height=sample_size[0],
... width=sample_size[1],
... num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
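# For example, a (h, w) = (480, 720) source fit into a 512x512 grid resizes to 341x512 and
# yields the centered crop region ((86, 0), (427, 512)).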
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
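# For example, with guidance_rescale=0.7 the returned prediction is
#     0.7 * noise_cfg * (std_text / std_cfg) + 0.3 * noise_cfg
# i.e. the CFG output is pulled 70% of the way toward the per-sample std of the
# text-conditioned prediction.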
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
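# A minimal usage sketch (hypothetical scheduler instance):
#
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda")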
class EasyAnimatePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using EasyAnimate.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
EasyAnimate V5.1 uses a single text encoder, [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
transformer ([`EasyAnimateTransformer3DModel`]):
The EasyAnimate model designed by EasyAnimate Team.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKLMagvit,
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
transformer: EasyAnimateTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.enable_text_attention_mask = (
self.transformer.config.enable_text_attention_mask
if getattr(self, "transformer", None) is not None
else True
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device (`torch.device`, *optional*):
torch device
dtype (`torch.dtype`, *optional*):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
"""
dtype = dtype or self.text_encoder.dtype
device = device or self.text_encoder.device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
if isinstance(prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _prompt}],
}
for _prompt in prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
prompt_embeds = self.text_encoder(
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
if negative_prompt is not None and isinstance(negative_prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": negative_prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _negative_prompt}],
}
for _negative_prompt in negative_prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
negative_prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
negative_prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=negative_prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
height // self.vae_spatial_compression_ratio,
width // self.vae_spatial_compression_ratio,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma
return latents
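# Shape sketch: with the default compression ratios (4x temporal, 8x spatial), num_frames=49
# and height = width = 512 yield latents of shape
#     [batch_size, num_channels_latents, (49 - 1) // 4 + 1, 512 // 8, 512 // 8]
#     = [batch_size, num_channels_latents, 13, 64, 64]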
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 49,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
timesteps: Optional[List[int]] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
):
r"""
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
Args:
prompt (`str` or `List[str]`, *optional*):
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
num_frames (`int`, *optional*):
Length of the generated video (in frames).
height (`int`, *optional*):
Height of the generated image in pixels.
width (`int`, *optional*):
Width of the generated image in pixels.
num_inference_steps (`int`, *optional*, defaults to 50):
Number of denoising steps during generation. More steps generally yield higher quality images but slow
down inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
Higher values push the generated output to align more closely with the prompt, potentially at the cost
of image quality.
negative_prompt (`str` or `List[str]`, *optional*):
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images to generate for each prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the DDIM scheduler and is ignored by others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A generator to ensure reproducibility in image generation.
latents (`torch.Tensor`, *optional*):
Predefined latent tensors to condition generation.
prompt_embeds (`torch.Tensor`, *optional*):
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Embeddings for negative prompts. Overrides string inputs if defined.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the primary prompt embeddings.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for negative prompt embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
Format of the generated output, either as a PIL image or as a NumPy array.
return_dict (`bool`, *optional*, defaults to `True`):
If `True`, returns a structured output. Otherwise returns a simple tuple.
callback_on_step_end (`Callable`, *optional*):
Functions called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor names to be included in callback function calls.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Adjusts noise levels based on guidance scale.
Examples:

Returns:
[`EasyAnimatePipelineOutput`] or `tuple`:
If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] is returned; otherwise, a `tuple` is
returned where the first element is a list with the generated frames.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = int((height // 16) * 16)
width = int((width // 16) * 16)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = self.transformer.dtype
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 4. Prepare timesteps
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, mu=1
)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
return_dict=False,
)[0]
if noise_pred.size()[1] != self.vae.config.latent_channels:
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = 1 / self.vae.config.scaling_factor * latents
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return EasyAnimatePipelineOutput(frames=video)
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
BertModel,
BertTokenizer,
Qwen2Tokenizer,
Qwen2VLForConditionalGeneration,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import EasyAnimatePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import EasyAnimateControlPipeline
>>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
>>> from diffusers.utils import export_to_video, load_video
>>> pipe = EasyAnimateControlPipeline.from_pretrained(
... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> control_video = load_video(
...     "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/resolve/main/asset/pose.mp4"
... )
>>> prompt = (
... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
... "releasing their fragrances, creating a relaxed and joyful atmosphere."
... )
>>> sample_size = (672, 384)
>>> num_frames = 49
>>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
>>> video = pipe(
... prompt,
... num_frames=num_frames,
... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
... height=sample_size[0],
... width=sample_size[1],
... control_video=input_video,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
def preprocess_image(image, sample_size):
"""
Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
"""
if isinstance(image, torch.Tensor):
# If input is a tensor, assume it's in CHW format and resize using interpolation
image = torch.nn.functional.interpolate(
image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
).squeeze(0)
elif isinstance(image, Image.Image):
# If input is a PIL image, resize and convert to numpy array
image = image.resize((sample_size[1], sample_size[0]))
image = np.array(image)
elif isinstance(image, np.ndarray):
# If input is a numpy array, resize using PIL
image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
image = np.array(image)
else:
raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
# Convert to tensor if not already
if not isinstance(image, torch.Tensor):
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
return image
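# A minimal usage sketch (hypothetical input): a 1280x720 PIL frame preprocessed with
# sample_size=(384, 672), i.e. (height, width), comes back as a float tensor of shape
# (3, 384, 672) with values in [0, 1].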
def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
if input_video is not None:
# Convert each frame in the list to tensor
input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
# Stack all frames into a single tensor (F, C, H, W)
input_video = torch.stack(input_video)[:num_frames]
# Reorder to (C, F, H, W) and add batch dimension -> (B, C, F, H, W)
input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
if validation_video_mask is not None:
# Handle mask input
validation_video_mask = preprocess_image(validation_video_mask, sample_size=sample_size)
input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
# Adjust mask dimensions to match video
input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
else:
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, :] = 255
else:
input_video, input_video_mask = None, None
if ref_image is not None:
# Convert reference image to tensor
ref_image = preprocess_image(ref_image, sample_size=sample_size)
ref_image = ref_image.unsqueeze(0)  # Add batch dimension (B, C, H, W)
else:
ref_image = None
return input_video, input_video_mask, ref_image
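# A minimal usage sketch (hypothetical file with at least 49 frames):
#
#     frames = load_video("pose.mp4")  # list of PIL frames
#     video, mask, ref = get_video_to_video_latent(frames, num_frames=49, sample_size=(384, 672))
#     # video: [1, 3, 49, 384, 672] (B, C, F, H, W); mask: all 255 because no mask was given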
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
# Resize the mask so its temporal/spatial dimensions match the latent produced by the MagViT VAE
def resize_mask(mask, latent, process_first_frame_only=True):
latent_size = latent.size()
if process_first_frame_only:
target_size = list(latent_size[2:])
target_size[0] = 1
first_frame_resized = F.interpolate(
mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
)
target_size = list(latent_size[2:])
target_size[0] = target_size[0] - 1
if target_size[0] != 0:
remaining_frames_resized = F.interpolate(
mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
)
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
else:
resized_mask = first_frame_resized
else:
target_size = list(latent_size[2:])
resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
return resized_mask
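# Shape sketch: a [1, 1, 49, 384, 672] pixel-space mask resized against a [1, 16, 13, 48, 84]
# latent keeps the first frame separate (interpolated to 1 latent frame) and maps the
# remaining 48 frames to 12, giving a [1, 1, 13, 48, 84] mask.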
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class EasyAnimateControlPipeline(DiffusionPipeline):
r"""
Pipeline for control-guided video generation using EasyAnimate.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
EasyAnimate V5.1 uses a single text encoder, [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
transformer ([`EasyAnimateTransformer3DModel`]):
The EasyAnimate model designed by EasyAnimate Team.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKLMagvit,
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
transformer: EasyAnimateTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.enable_text_attention_mask = (
self.transformer.config.enable_text_attention_mask
if getattr(self, "transformer", None) is not None
else True
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_spatial_compression_ratio,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device (`torch.device`, *optional*):
torch device
dtype (`torch.dtype`, *optional*):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
"""
dtype = dtype or self.text_encoder.dtype
device = device or self.text_encoder.device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
if isinstance(prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _prompt}],
}
for _prompt in prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
            if self.enable_text_attention_mask:
                # Encode with the text encoder and use the second-to-last hidden state as the prompt embedding.
                prompt_embeds = self.text_encoder(
                    input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
                ).hidden_states[-2]
            else:
                raise ValueError("The text encoder requires an attention mask (`enable_text_attention_mask=True`).")
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
if negative_prompt is not None and isinstance(negative_prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": negative_prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _negative_prompt}],
}
for _negative_prompt in negative_prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
negative_prompt_attention_mask = text_inputs.attention_mask
            if self.enable_text_attention_mask:
                # Encode with the text encoder and use the second-to-last hidden state as the prompt embedding.
                negative_prompt_embeds = self.text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=negative_prompt_attention_mask,
                    output_hidden_states=True,
                ).hidden_states[-2]
            else:
                raise ValueError("The text encoder requires an attention mask (`enable_text_attention_mask=True`).")
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
height // self.vae_spatial_compression_ratio,
width // self.vae_spatial_compression_ratio,
)
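        # With the defaults (49 frames at 512x512, temporal ratio 4, spatial ratio 8) this gives a
        # latent of shape (batch, C, (49 - 1) // 4 + 1 = 13, 512 // 8 = 64, 512 // 8 = 64).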
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_control_latents(
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the control to latents shape as we concatenate the control to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
if control is not None:
control = control.to(device=device, dtype=dtype)
bs = 1
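            # Encode through the VAE in micro-batches of one sample to bound peak memory usage.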
new_control = []
for i in range(0, control.shape[0], bs):
control_bs = control[i : i + bs]
control_bs = self.vae.encode(control_bs)[0]
control_bs = control_bs.mode()
new_control.append(control_bs)
control = torch.cat(new_control, dim=0)
control = control * self.vae.config.scaling_factor
if control_image is not None:
control_image = control_image.to(device=device, dtype=dtype)
bs = 1
new_control_pixel_values = []
for i in range(0, control_image.shape[0], bs):
control_pixel_values_bs = control_image[i : i + bs]
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
control_pixel_values_bs = control_pixel_values_bs.mode()
new_control_pixel_values.append(control_pixel_values_bs)
control_image_latents = torch.cat(new_control_pixel_values, dim=0)
control_image_latents = control_image_latents * self.vae.config.scaling_factor
else:
control_image_latents = None
return control, control_image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 49,
height: Optional[int] = 512,
width: Optional[int] = 512,
        control_video: Optional[torch.FloatTensor] = None,
        control_camera_video: Optional[torch.FloatTensor] = None,
        ref_image: Optional[torch.FloatTensor] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
timesteps: Optional[List[int]] = None,
):
r"""
        Generates a video with the EasyAnimate control pipeline, guided by the provided prompts and control signals.
        Args:
prompt (`str` or `List[str]`, *optional*):
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
num_frames (`int`, *optional*):
Length of the generated video (in frames).
height (`int`, *optional*):
Height of the generated image in pixels.
            width (`int`, *optional*):
                Width of the generated image in pixels.
            control_video (`torch.FloatTensor`, *optional*):
                A control video of shape `(batch, channels, frames, height, width)` used to guide the generation.
            control_camera_video (`torch.FloatTensor`, *optional*):
                A camera control video used to guide camera motion during generation; checked before `control_video`.
            ref_image (`torch.FloatTensor`, *optional*):
                A reference image used to condition the generation; its latents are placed in the first frame slot.
num_inference_steps (`int`, *optional*, defaults to 50):
Number of denoising steps during generation. More steps generally yield higher quality images but slow
down inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale encourages outputs closely aligned with the prompt, possibly at the expense
                of image quality. Guidance is effective when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images to generate for each prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`] and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A generator to ensure reproducibility in image generation.
latents (`torch.Tensor`, *optional*):
Predefined latent tensors to condition generation.
prompt_embeds (`torch.Tensor`, *optional*):
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Embeddings for negative prompts. Overrides string inputs if defined.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the primary prompt embeddings.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for negative prompt embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Format of the generated output, either as a PIL image or as a NumPy array.
return_dict (`bool`, *optional*, defaults to `True`):
If `True`, returns a structured output. Otherwise returns a simple tuple.
callback_on_step_end (`Callable`, *optional*):
Functions called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor names to be included in callback function calls.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Rescales the noise prediction according to the guidance scale, based on [Common Diffusion Noise
                Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not provided, the scheduler's default timestep
                spacing is used.
        Examples:
        Returns:
            [`~pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated
                frames.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        # 0. Round height and width down to multiples of 16
height = int((height // 16) * 16)
width = int((width // 16) * 16)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = self.transformer.dtype
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 4. Prepare timesteps
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, mu=1
)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents,
)
if control_camera_video is not None:
control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
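            # Scale the resized camera control signal by a fixed factor of 6, a constant carried over
            # from the upstream EasyAnimate implementation (presumably to match the latent value range).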
control_video_latents = control_video_latents * 6
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
elif control_video is not None:
batch_size, channels, num_frames, height_video, width_video = control_video.shape
control_video = self.image_processor.preprocess(
control_video.permute(0, 2, 1, 3, 4).reshape(
batch_size * num_frames, channels, height_video, width_video
),
height=height,
width=width,
)
control_video = control_video.to(dtype=torch.float32)
control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
0, 2, 1, 3, 4
)
control_video_latents = self.prepare_control_latents(
None,
control_video,
batch_size,
height,
width,
dtype,
device,
generator,
self.do_classifier_free_guidance,
)[1]
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
else:
control_video_latents = torch.zeros_like(latents).to(device, dtype)
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
if ref_image is not None:
batch_size, channels, num_frames, height_video, width_video = ref_image.shape
ref_image = self.image_processor.preprocess(
ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
height=height,
width=width,
)
ref_image = ref_image.to(dtype=torch.float32)
ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
ref_image_latents = self.prepare_control_latents(
None,
ref_image,
batch_size,
height,
width,
prompt_embeds.dtype,
device,
generator,
self.do_classifier_free_guidance,
)[1]
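            # Only the first temporal latent slot carries the reference image; all later frames stay zero.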
ref_image_latents_conv_in = torch.zeros_like(latents)
if latents.size()[2] != 1:
ref_image_latents_conv_in[:, :, :1] = ref_image_latents
ref_image_latents_conv_in = (
torch.cat([ref_image_latents_conv_in] * 2)
if self.do_classifier_free_guidance
else ref_image_latents_conv_in
).to(device, dtype)
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
else:
ref_image_latents_conv_in = torch.zeros_like(latents)
ref_image_latents_conv_in = (
torch.cat([ref_image_latents_conv_in] * 2)
if self.do_classifier_free_guidance
else ref_image_latents_conv_in
).to(device, dtype)
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
# To latents.device
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
control_latents=control_latents,
return_dict=False,
)[0]
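                # Some EasyAnimate variants predict twice the latent channels; keep only the first
                # half as the noise prediction and discard the rest.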
if noise_pred.size()[1] != self.vae.config.latent_channels:
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
        # 8. Post-processing: decode latents into frames unless latent output is requested
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return EasyAnimatePipelineOutput(frames=video)
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
BertModel,
BertTokenizer,
Qwen2Tokenizer,
Qwen2VLForConditionalGeneration,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import EasyAnimatePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import EasyAnimateInpaintPipeline
>>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent
>>> from diffusers.utils import export_to_video, load_image
>>> pipe = EasyAnimateInpaintPipeline.from_pretrained(
... "alibaba-pai/EasyAnimateV5.1-12b-zh-InP-diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
>>> validation_image_start = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
... )
>>> validation_image_end = None
>>> sample_size = (448, 576)
>>> num_frames = 49
>>> input_video, input_video_mask = get_image_to_video_latent(
... [validation_image_start], validation_image_end, num_frames, sample_size
... )
>>> video = pipe(
... prompt,
... num_frames=num_frames,
... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
... height=sample_size[0],
... width=sample_size[1],
... video=input_video,
... mask_video=input_video_mask,
... )
>>> export_to_video(video.frames[0], "output.mp4", fps=8)
```
"""
def preprocess_image(image, sample_size):
"""
    Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) into a resized CHW tensor in [0, 1].
    Tensor inputs are assumed to already be CHW and normalized, and are only resized.
"""
if isinstance(image, torch.Tensor):
# If input is a tensor, assume it's in CHW format and resize using interpolation
image = torch.nn.functional.interpolate(
image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
).squeeze(0)
elif isinstance(image, Image.Image):
# If input is a PIL image, resize and convert to numpy array
image = image.resize((sample_size[1], sample_size[0]))
image = np.array(image)
elif isinstance(image, np.ndarray):
# If input is a numpy array, resize using PIL
image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
image = np.array(image)
else:
raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
# Convert to tensor if not already
if not isinstance(image, torch.Tensor):
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
return image
def get_image_to_video_latent(validation_image_start, validation_image_end, num_frames, sample_size):
"""
    Generate the padded input video and mask for image-to-video generation from start and end images. Inputs can be
    PIL.Image, numpy.ndarray, or torch.Tensor. In the returned mask, 0 marks provided (conditioning) frames and 255
    marks frames to be generated.
"""
input_video = None
input_video_mask = None
if validation_image_start is not None:
# Preprocess the starting image(s)
if isinstance(validation_image_start, list):
image_start = [preprocess_image(img, sample_size) for img in validation_image_start]
else:
image_start = preprocess_image(validation_image_start, sample_size)
# Create video tensor from the starting image(s)
if isinstance(image_start, list):
start_video = torch.cat(
[img.unsqueeze(1).unsqueeze(0) for img in image_start],
dim=2,
)
input_video = torch.tile(start_video[:, :, :1], [1, 1, num_frames, 1, 1])
input_video[:, :, : len(image_start)] = start_video
else:
input_video = torch.tile(
image_start.unsqueeze(1).unsqueeze(0),
[1, 1, num_frames, 1, 1],
)
        # Pixel values were already normalized to [0, 1] in preprocess_image
# Create mask for the input video
input_video_mask = torch.zeros_like(input_video[:, :1])
if isinstance(image_start, list):
input_video_mask[:, :, len(image_start) :] = 255
else:
input_video_mask[:, :, 1:] = 255
# Handle ending image(s) if provided
if validation_image_end is not None:
if isinstance(validation_image_end, list):
image_end = [preprocess_image(img, sample_size) for img in validation_image_end]
end_video = torch.cat(
[img.unsqueeze(1).unsqueeze(0) for img in image_end],
dim=2,
)
                input_video[:, :, -len(image_end) :] = end_video
input_video_mask[:, :, -len(image_end) :] = 0
else:
image_end = preprocess_image(validation_image_end, sample_size)
input_video[:, :, -1:] = image_end.unsqueeze(1).unsqueeze(0)
input_video_mask[:, :, -1:] = 0
elif validation_image_start is None:
# If no starting image is provided, initialize empty tensors
input_video = torch.zeros([1, 3, num_frames, sample_size[0], sample_size[1]])
input_video_mask = torch.ones([1, 1, num_frames, sample_size[0], sample_size[1]]) * 255
return input_video, input_video_mask
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
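    # e.g. src=(480, 640) with a 512x512 target resizes to 384x512 and returns the centered
    # crop region ((64, 0), (448, 512))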
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
# Resize a mask to the latent resolution of the MagViT VAE; the first frame can be handled
# separately because the VAE compresses it differently from the remaining frames.
def resize_mask(mask, latent, process_first_frame_only=True):
latent_size = latent.size()
if process_first_frame_only:
target_size = list(latent_size[2:])
target_size[0] = 1
first_frame_resized = F.interpolate(
mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
)
target_size = list(latent_size[2:])
target_size[0] = target_size[0] - 1
if target_size[0] != 0:
remaining_frames_resized = F.interpolate(
mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
)
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
else:
resized_mask = first_frame_resized
else:
target_size = list(latent_size[2:])
resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
return resized_mask
# Add noise to the reference video to augment the inpainting conditioning
def add_noise_to_reference_video(image, ratio=None, generator=None):
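    # When no explicit ratio is given, sample a per-video noise level from a log-normal
    # distribution (mean -3.0, std 0.5 in log space).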
if ratio is None:
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
sigma = torch.exp(sigma).to(image.dtype)
else:
sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
if generator is not None:
image_noise = (
torch.randn(image.size(), generator=generator, dtype=image.dtype, device=image.device)
* sigma[:, None, None, None, None]
)
else:
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
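    # Pixels marked with -1 (masked-out regions) receive no noise.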
image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
image = image + image_noise
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class EasyAnimateInpaintPipeline(DiffusionPipeline):
r"""
    Pipeline for video inpainting and image-to-video generation using EasyAnimate.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    EasyAnimate V5.1 uses a single text encoder, [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
transformer ([`EasyAnimateTransformer3DModel`]):
            The EasyAnimate transformer model designed by the EasyAnimate team.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKLMagvit,
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
transformer: EasyAnimateTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.enable_text_attention_mask = (
self.transformer.config.enable_text_attention_mask
if getattr(self, "transformer", None) is not None
else True
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_spatial_compression_ratio,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
            device (`torch.device`, *optional*):
                The torch device on which the embeddings are created.
            dtype (`torch.dtype`, *optional*):
                The torch dtype of the embeddings.
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
            max_sequence_length (`int`, defaults to `256`): Maximum sequence length to use for the prompt.
"""
dtype = dtype or self.text_encoder.dtype
device = device or self.text_encoder.device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
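            # Wrap the raw prompt(s) in chat-style messages so the Qwen2-VL tokenizer can apply
            # its chat template before tokenization.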
if isinstance(prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _prompt}],
}
for _prompt in prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
            if self.enable_text_attention_mask:
                # Encode with the text encoder and use the second-to-last hidden state as the prompt embedding.
                prompt_embeds = self.text_encoder(
                    input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
                ).hidden_states[-2]
            else:
                raise ValueError("The text encoder requires an attention mask (`enable_text_attention_mask=True`).")
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
if negative_prompt is not None and isinstance(negative_prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": negative_prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _negative_prompt}],
}
for _negative_prompt in negative_prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
negative_prompt_attention_mask = text_inputs.attention_mask
            if self.enable_text_attention_mask:
                # Encode with the text encoder and use the second-to-last hidden state as the prompt embedding.
                negative_prompt_embeds = self.text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=negative_prompt_attention_mask,
                    output_hidden_states=True,
                ).hidden_states[-2]
            else:
                raise ValueError("The text encoder requires an attention mask (`enable_text_attention_mask=True`).")
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
device,
generator,
do_classifier_free_guidance,
noise_aug_strength,
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
if mask is not None:
mask = mask.to(device=device, dtype=dtype)
new_mask = []
bs = 1
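            # Encode through the VAE in micro-batches of one sample to bound peak memory usage.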
for i in range(0, mask.shape[0], bs):
mask_bs = mask[i : i + bs]
mask_bs = self.vae.encode(mask_bs)[0]
mask_bs = mask_bs.mode()
new_mask.append(mask_bs)
mask = torch.cat(new_mask, dim=0)
mask = mask * self.vae.config.scaling_factor
if masked_image is not None:
masked_image = masked_image.to(device=device, dtype=dtype)
if self.transformer.config.add_noise_in_inpaint_model:
masked_image = add_noise_to_reference_video(
masked_image, ratio=noise_aug_strength, generator=generator
)
new_mask_pixel_values = []
bs = 1
for i in range(0, masked_image.shape[0], bs):
mask_pixel_values_bs = masked_image[i : i + bs]
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
mask_pixel_values_bs = mask_pixel_values_bs.mode()
new_mask_pixel_values.append(mask_pixel_values_bs)
masked_image_latents = torch.cat(new_mask_pixel_values, dim=0)
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
else:
masked_image_latents = None
return mask, masked_image_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents=None,
video=None,
timestep=None,
is_strength_max=True,
return_noise=False,
return_video_latents=False,
):
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
height // self.vae_spatial_compression_ratio,
width // self.vae_spatial_compression_ratio,
)
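        # With the defaults (49 frames at 512x512, temporal ratio 4, spatial ratio 8) this gives a
        # latent of shape (batch, C, (49 - 1) // 4 + 1 = 13, 512 // 8 = 64, 512 // 8 = 64).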
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if return_video_latents or (latents is None and not is_strength_max):
video = video.to(device=device, dtype=dtype)
bs = 1
new_video = []
for i in range(0, video.shape[0], bs):
video_bs = video[i : i + bs]
video_bs = self.vae.encode(video_bs)[0]
video_bs = video_bs.sample()
new_video.append(video_bs)
video = torch.cat(new_video, dim=0)
video = video * self.vae.config.scaling_factor
video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
video_latents = video_latents.to(device=device, dtype=dtype)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise)
else:
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
else:
if hasattr(self.scheduler, "init_noise_sigma"):
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_video_latents:
outputs += (video_latents,)
return outputs
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 49,
        video: Optional[torch.FloatTensor] = None,
        mask_video: Optional[torch.FloatTensor] = None,
        masked_video_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
strength: float = 1.0,
noise_aug_strength: float = 0.0563,
timesteps: Optional[List[int]] = None,
):
r"""
        The call function to the pipeline for video generation with EasyAnimate.
        Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            num_frames (`int`, *optional*, defaults to 49):
                Length of the generated video in frames.
video (`torch.FloatTensor`, *optional*):
A tensor representing an input video, which can be modified depending on the prompts provided.
mask_video (`torch.FloatTensor`, *optional*):
A tensor to specify areas of the video to be masked (omitted from generation).
masked_video_latents (`torch.FloatTensor`, *optional*):
Latents from masked portions of the video, utilized during image generation.
height (`int`, *optional*):
The height in pixels of the generated image or video frames.
width (`int`, *optional*):
The width in pixels of the generated image or video frames.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
inference time. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to exclude in image generation. If not defined, you need to provide
`negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
[`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
inference process.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
random seeds which helps in making generation deterministic.
latents (`torch.Tensor`, *optional*):
A pre-computed latent representation which can be used to guide the generation process.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the
outputs. If not provided, embeddings are generated from the `negative_prompt` argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
`prompt_embeds`.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between `PIL.Image` and `np.array` to define how
                you want the results to be formatted.
return_dict (`bool`, *optional*, defaults to `True`):
If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned;
otherwise, a tuple containing the generated images and safety flags will be returned.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`,
*optional*):
A callback function (or a list of them) that will be executed at the end of each denoising step,
allowing for custom processing during generation.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            strength (`float`, *optional*, defaults to 1.0):
                Indicates how much to transform the reference `video`. Must be between 0 and 1: a value of 1 starts
                from pure noise and essentially ignores `video`, while lower values preserve more of it.
            noise_aug_strength (`float`, *optional*, defaults to 0.0563):
                Strength of the noise added to the masked reference video before encoding, used to augment the
                inpainting conditioning.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not provided, the scheduler's default timestep
                spacing is used.
        Examples:
        Returns:
            [`~pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated
                frames.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        # 0. Round height and width down to multiples of 16
height = int(height // 16 * 16)
width = int(width // 16 * 16)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = self.transformer.dtype
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 4. set timesteps
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, mu=1
)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
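# Trim the schedule according to `strength`: lower strength starts denoising at a later
# timestep, so more of the input video is preserved.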
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
if video is not None:
batch_size, channels, num_frames, height_video, width_video = video.shape
init_video = self.image_processor.preprocess(
video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
height=height,
width=width,
)
init_video = init_video.to(dtype=torch.float32)
init_video = init_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
else:
init_video = None
# Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_transformer = self.transformer.config.in_channels
return_image_latents = num_channels_transformer == num_channels_latents
# 5. Prepare latents.
latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents,
video=init_video,
timestep=latent_timestep,
is_strength_max=is_strength_max,
return_noise=True,
return_video_latents=return_image_latents,
)
if return_image_latents:
latents, noise, image_latents = latents_outputs
else:
latents, noise = latents_outputs
# 6. Prepare inpaint latents if needed.
if mask_video is not None:
if (mask_video == 255).all():
mask = torch.zeros_like(latents).to(device, dtype)
# Use zero latents for pure text-to-video generation (nothing to preserve).
if self.transformer.config.resize_inpaint_mask_directly:
mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
else:
mask_latents = torch.zeros_like(latents).to(device, dtype)
masked_video_latents = torch.zeros_like(latents).to(device, dtype)
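# Duplicate the inpaint conditioning along the batch dimension so the unconditional and
# text-conditioned branches of classifier-free guidance receive identical inputs.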
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
else:
# Prepare mask latent variables
batch_size, channels, num_frames, height_video, width_video = mask_video.shape
mask_condition = self.mask_processor.preprocess(
mask_video.permute(0, 2, 1, 3, 4).reshape(
batch_size * num_frames, channels, height_video, width_video
),
height=height,
width=width,
)
mask_condition = mask_condition.to(dtype=torch.float32)
mask_condition = mask_condition.reshape(batch_size, num_frames, channels, height, width).permute(
0, 2, 1, 3, 4
)
if num_channels_transformer != num_channels_latents:
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
if masked_video_latents is None:
masked_video = (
init_video * (mask_condition_tile < 0.5)
+ torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
)
else:
masked_video = masked_video_latents
if self.transformer.config.resize_inpaint_mask_directly:
_, masked_video_latents = self.prepare_mask_latents(
None,
masked_video,
batch_size,
height,
width,
dtype,
device,
generator,
self.do_classifier_free_guidance,
noise_aug_strength=noise_aug_strength,
)
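# Downsample the inverted pixel-space mask directly to latent resolution instead of
# encoding it through the VAE.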
mask_latents = resize_mask(
1 - mask_condition, masked_video_latents, self.vae.config.cache_mag_vae
)
mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor
else:
mask_latents, masked_video_latents = self.prepare_mask_latents(
mask_condition_tile,
masked_video,
batch_size,
height,
width,
dtype,
device,
generator,
self.do_classifier_free_guidance,
noise_aug_strength=noise_aug_strength,
)
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2)
if self.do_classifier_free_guidance
else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
else:
inpaint_latents = None
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to(
device, dtype
)
else:
if num_channels_transformer != num_channels_latents:
mask = torch.zeros_like(latents).to(device, dtype)
if self.transformer.config.resize_inpaint_mask_directly:
mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
else:
mask_latents = torch.zeros_like(latents).to(device, dtype)
masked_video_latents = torch.zeros_like(latents).to(device, dtype)
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
else:
mask = torch.zeros_like(init_video[:, :1])
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to(
device, dtype
)
inpaint_latents = None
# Check that sizes of mask, masked image and latents match
if num_channels_transformer != num_channels_latents:
num_channels_mask = mask_latents.shape[1]
num_channels_masked_image = masked_video_latents.shape[1]
if (
num_channels_latents + num_channels_mask + num_channels_masked_image
!= self.transformer.config.in_channels
):
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
# Move text embeddings and masks to the execution device.
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
inpaint_latents=inpaint_latents,
return_dict=False,
)[0]
if noise_pred.size()[1] != self.vae.config.latent_channels:
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
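# Without dedicated inpaint channels in the transformer, enforce the known regions
# explicitly: re-noise the reference latents to the next timestep and use them wherever
# the mask marks content to keep (standard latent-inpainting blending).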
if num_channels_transformer == num_channels_latents:
init_latents_proper = image_latents
init_mask = mask
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
init_latents_proper = self.scheduler.scale_noise(
init_latents_proper, torch.tensor([noise_timestep]), noise
)
else:
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep])
)
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type != "latent":
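# Rescale latents back to the VAE's native range before decoding them into frames.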
latents = 1 / self.vae.config.scaling_factor * latents
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return EasyAnimatePipelineOutput(frames=video)
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class EasyAnimatePipelineOutput(BaseOutput):
r"""
Output class for EasyAnimate pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
    List of video outputs. It can be a nested list of length `batch_size`, with each sub-list
    containing denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or a
    Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
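# Illustrative access pattern (the names below are assumptions for demonstration):
#     output = pipe(prompt="...", output_type="pil")
#     frames = output.frames[0]  # first video in the batch, a list of PIL images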
......@@ -171,6 +171,21 @@ class AutoencoderKLLTXVideo(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLMagvit(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLMochi(metaclass=DummyObject):
_backends = ["torch"]
......@@ -396,6 +411,21 @@ class DiTTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class EasyAnimateTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FluxControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -407,6 +407,51 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class EasyAnimateControlPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class EasyAnimateInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class EasyAnimatePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxControlImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLMagvit
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_magvit_config(self):
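# Deliberately tiny configuration (not matching any released checkpoint) so the shared
# model tests run quickly.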
return {
"in_channels": 3,
"latent_channels": 4,
"out_channels": 3,
"block_out_channels": [8, 8, 8, 8],
"down_block_types": [
"SpatialDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
],
"up_block_types": [
"SpatialUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
],
"layers_per_block": 1,
"norm_num_groups": 8,
"spatial_group_norm": True,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
height = 16
width = 16
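# 5D video tensor laid out as (batch, channels, frames, height, width).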
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_magvit_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Not quite sure why this test fails. Revisit later.")
def test_effective_gradient_checkpointing(self):
pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass