Unverified Commit d7ffe601 authored by Aryan, committed by GitHub

Hunyuan Video Framepack (#11428)

* add transformer

* add pipeline

* fixes

* make fix-copies

* update

* add flux mu shift

* update example snippet

* debug

* cleanup

* batch_size=1 optimization

* add pipeline test

* fix for model cpu offloading

* add last_image support; credits: https://github.com/lllyasviel/FramePack/pull/167

* update example with flf2v

* update penguin url

* fix test

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071032371

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071087689



* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 10bee525
@@ -52,6 +52,7 @@ The following models are available for the image-to-video pipeline:
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution with `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's model implementing [Framepack](https://arxiv.org/abs/2504.12626), a new technique for long-context video generation. |
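
The Framepack transformer reuses the standard HunyuanVideo text encoders, VAE and scheduler, and adds a SigLIP vision encoder for image conditioning. Below is a minimal loading sketch; the SigLIP checkpoint ID is an assumption, so substitute the image encoder you actually use.

```python
import torch
from transformers import SiglipImageProcessor, SiglipVisionModel

from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel

# Framepack-specific transformer weights
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
)

# Assumed SigLIP checkpoint for image conditioning; swap in the one you have locally
feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="feature_extractor")
image_encoder = SiglipVisionModel.from_pretrained(
    "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)

# Remaining components (text encoders, tokenizers, VAE, scheduler) come from the base HunyuanVideo repo
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.vae.enable_tiling()
pipe.to("cuda")
```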
## Quantization
@@ -175,6 +175,7 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
@@ -376,6 +377,7 @@ else:
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
@@ -770,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
@@ -950,6 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
@@ -79,6 +79,7 @@ if is_torch_available():
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -156,6 +157,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
@@ -23,6 +23,7 @@ if is_torch_available():
from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from .transformer_hunyuan_video import (
HunyuanVideoConditionEmbedding,
HunyuanVideoPatchEmbed,
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenRefiner,
HunyuanVideoTransformerBlock,
)
logger = get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.rope_dim = rope_dim
self.theta = theta
def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
height = height // self.patch_size
width = width // self.patch_size
grid = torch.meshgrid(
frame_indices.to(device=device, dtype=torch.float32),
torch.arange(0, height, device=device, dtype=torch.float32),
torch.arange(0, width, device=device, dtype=torch.float32),
indexing="ij",
) # 3 * [T, H, W]
grid = torch.stack(grid, dim=0) # [3, T, H, W]
freqs = []
for i in range(3):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (T * H * W, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (T * H * W, D / 2)
return freqs_cos, freqs_sin
class FramepackClipVisionProjection(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.up = nn.Linear(in_channels, out_channels * 3)
self.down = nn.Linear(out_channels * 3, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.up(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.down(hidden_states)
return hidden_states
class HunyuanVideoHistoryPatchEmbed(nn.Module):
def __init__(self, in_channels: int, inner_dim: int):
super().__init__()
self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
def forward(
self,
latents_clean: Optional[torch.Tensor] = None,
latents_clean_2x: Optional[torch.Tensor] = None,
latents_clean_4x: Optional[torch.Tensor] = None,
):
if latents_clean is not None:
latents_clean = self.proj(latents_clean)
latents_clean = latents_clean.flatten(2).transpose(1, 2)
if latents_clean_2x is not None:
latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
latents_clean_2x = self.proj_2x(latents_clean_2x)
latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
if latents_clean_4x is not None:
latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
latents_clean_4x = self.proj_4x(latents_clean_4x)
latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
return latents_clean, latents_clean_2x, latents_clean_4x
class HunyuanVideoFramepackTransformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoHistoryPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: int = 2,
patch_size_t: int = 1,
qk_norm: str = "rms_norm",
guidance_embeds: bool = True,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int, int, int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
has_image_proj: bool = False,
image_proj_dim: int = 1152,
has_clean_x_embedder: bool = False,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
# 2. RoPE
self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
# Framepack specific modules
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.clean_x_embedder = None
if has_clean_x_embedder:
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
image_embeds: torch.Tensor,
indices_latents: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
latents_clean: Optional[torch.Tensor] = None,
indices_latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
indices_latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
if indices_latents is None:
indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
hidden_states = self.x_embedder(hidden_states)
image_rotary_emb = self.rope(
frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
)
latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
latents_clean, latents_history_2x, latents_history_4x
)
if latents_clean is not None and indices_latents_clean is not None:
image_rotary_emb_clean = self.rope(
frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
)
if latents_history_2x is not None and indices_latents_history_2x is not None:
image_rotary_emb_history_2x = self.rope(
frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
)
if latents_history_4x is not None and indices_latents_history_4x is not None:
image_rotary_emb_history_4x = self.rope(
frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
)
hidden_states, image_rotary_emb = self._pack_history_states(
hidden_states,
latents_clean,
latents_history_2x,
latents_history_4x,
image_rotary_emb,
image_rotary_emb_clean,
image_rotary_emb_history_2x,
image_rotary_emb_history_4x,
post_patch_height,
post_patch_width,
)
temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
encoder_hidden_states_image = self.image_projection(image_embeds)
attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
# image tokens must be concatenated before (not after) the text tokens, because the attention mask marks valid tokens from the start of the sequence
encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
if batch_size == 1:
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
attention_mask = None
else:
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
hidden_states = hidden_states[:, -original_context_length:]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
def _pack_history_states(
self,
hidden_states: torch.Tensor,
latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
height: int = None,
width: int = None,
):
image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
if latents_clean is not None and image_rotary_emb_clean is not None:
hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
return hidden_states, tuple(image_rotary_emb)
def _pad_rotary_emb(
self,
image_rotary_emb: Tuple[torch.Tensor],
height: int,
width: int,
kernel_size: Tuple[int, int, int],
):
# freqs_cos, freqs_sin have shape [T * H * W, D / 2], where D is the attention head dim
freqs_cos, freqs_sin = image_rotary_emb
freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
return freqs_cos, freqs_sin
def _pad_for_3d_conv(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
b, c, t, h, w = x.shape
pt, ph, pw = kernel_size
pad_t = (pt - (t % pt)) % pt
pad_h = (ph - (h % ph)) % ph
pad_w = (pw - (w % pw)) % pw
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def _center_down_sample_3d(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
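# Illustration (not part of the library file above): the replicate-padding and
# center-downsampling helpers compress history latents into coarser token grids.
# The import path assumes this module ships as
# diffusers.models.transformers.transformer_hunyuan_video_framepack, per this commit.
import torch

from diffusers.models.transformers.transformer_hunyuan_video_framepack import (
    _center_down_sample_3d,
    _pad_for_3d_conv,
)

history = torch.randn(1, 16, 3, 5, 5)  # [B, C, T, H, W] latents from previously generated sections
padded = _pad_for_3d_conv(history, (2, 4, 4))  # replicate-pad each dim to a multiple of the kernel -> (1, 16, 4, 8, 8)
pooled = _center_down_sample_3d(padded, (2, 4, 4))  # average-pool with matching stride -> (1, 16, 2, 2, 2)
print(padded.shape, pooled.shape)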
@@ -227,6 +227,7 @@ else:
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
@@ -589,6 +590,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .hidream_image import HiDreamImagePipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -36,6 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
else:
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
import torch
from diffusers.utils import BaseOutput
@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput):
"""
frames: torch.Tensor
@dataclass
class HunyuanVideoFramepackPipelineOutput(BaseOutput):
r"""
Output class for the HunyuanVideo Framepack pipeline.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`, or a list of torch tensors where each tensor
corresponds to a latent that decodes to multiple frames.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
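# Illustrative consumption of this output class (not part of the library file above).
# Assumes `pipe` is an already loaded HunyuanVideoFramepackPipeline and `image` is a PIL
# start frame; `export_to_video` is an existing diffusers utility.
from diffusers.utils import export_to_video

output = pipe(image=image, prompt="A penguin dancing in the snow", num_frames=91, output_type="pil")
frames = output.frames[0]  # with output_type="pil", frames is a nested list: one PIL sequence per prompt
export_to_video(frames, "framepack.mp4", fps=30)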
@@ -565,6 +565,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -692,6 +692,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
LlamaConfig,
LlamaModel,
LlamaTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoFramepackPipeline,
HunyuanVideoFramepackTransformer3DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)
enable_full_determinism()
class HunyuanVideoFramepackPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanVideoFramepackPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["image", "prompt"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
# there is no xformers attention processor for HunyuanVideo
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoFramepackTransformer3DModel(
in_channels=4,
out_channels=4,
num_attention_heads=2,
attention_head_dim=10,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=2,
patch_size_t=1,
guidance_embeds=True,
text_embed_dim=16,
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
image_condition_type=None,
has_image_proj=True,
image_proj_dim=32,
has_clean_x_embedder=True,
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
latent_channels=4,
down_block_types=(
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types=(
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
act_fn="silu",
norm_num_groups=4,
scaling_factor=0.476986,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
mid_block_add_attention=True,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
llama_text_encoder_config = LlamaConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = LlamaModel(llama_text_encoder_config)
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = SiglipImageProcessor.from_pretrained(
"hf-internal-testing/tiny-random-SiglipVisionModel", size={"height": 30, "width": 30}
)
image_encoder = SiglipVisionModel.from_pretrained("hf-internal-testing/tiny-random-SiglipVisionModel")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"feature_extractor": feature_extractor,
"image_encoder": image_encoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 32
image_width = 32
image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"prompt_template": {
"template": "{}",
"crop_start": 0,
},
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.5,
"height": image_height,
"width": image_width,
"num_frames": 9,
"latent_window_size": 3,
"max_sequence_length": 256,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (13, 3, 32, 32))
expected_video = torch.randn(13, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
# Seems to require higher tolerance than the other tests
expected_diff_max = 0.6
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass