Unverified Commit d7ffe601 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Hunyuan Video Framepack (#11428)

* add transformer

* add pipeline

* fixes

* make fix-copies

* update

* add flux mu shift

* update example snippet

* debug

* cleanup

* batch_size=1 optimization

* add pipeline test

* fix for model cpu offloading'

* add last_image support; credits: https://github.com/lllyasviel/FramePack/pull/167

* update example with flf2v

* update penguin url

* fix test

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071032371

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071087689



* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

---------
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 10bee525
...@@ -52,6 +52,7 @@ The following models are available for the image-to-video pipeline: ...@@ -52,6 +52,7 @@ The following models are available for the image-to-video pipeline:
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. | | [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). | | [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) | | [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's paper introducing a new technique for long-context video generation called [Framepack](https://arxiv.org/abs/2504.12626). |
## Quantization ## Quantization
......
...@@ -175,6 +175,7 @@ else: ...@@ -175,6 +175,7 @@ else:
"HunyuanDiT2DControlNetModel", "HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel", "HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel", "HunyuanDiT2DMultiControlNetModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel", "HunyuanVideoTransformer3DModel",
"I2VGenXLUNet", "I2VGenXLUNet",
"Kandinsky3UNet", "Kandinsky3UNet",
...@@ -376,6 +377,7 @@ else: ...@@ -376,6 +377,7 @@ else:
"HunyuanDiTPAGPipeline", "HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline", "HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline", "HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline", "HunyuanVideoPipeline",
"I2VGenXLPipeline", "I2VGenXLPipeline",
...@@ -770,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -770,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel, HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel, HunyuanDiT2DMultiControlNetModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel, HunyuanVideoTransformer3DModel,
I2VGenXLUNet, I2VGenXLUNet,
Kandinsky3UNet, Kandinsky3UNet,
...@@ -950,6 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -950,6 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline, HunyuanDiTPAGPipeline,
HunyuanDiTPipeline, HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline, HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline, HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline, HunyuanVideoPipeline,
I2VGenXLPipeline, I2VGenXLPipeline,
......
...@@ -79,6 +79,7 @@ if is_torch_available(): ...@@ -79,6 +79,7 @@ if is_torch_available():
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
...@@ -156,6 +157,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -156,6 +157,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel, FluxTransformer2DModel,
HiDreamImageTransformer2DModel, HiDreamImageTransformer2DModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel, HunyuanVideoTransformer3DModel,
LatteTransformer3DModel, LatteTransformer3DModel,
LTXVideoTransformer3DModel, LTXVideoTransformer3DModel,
......
...@@ -23,6 +23,7 @@ if is_torch_available(): ...@@ -23,6 +23,7 @@ if is_torch_available():
from .transformer_flux import FluxTransformer2DModel from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel from .transformer_mochi import MochiTransformer3DModel
......
# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from .transformer_hunyuan_video import (
HunyuanVideoConditionEmbedding,
HunyuanVideoPatchEmbed,
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenRefiner,
HunyuanVideoTransformerBlock,
)
logger = get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.rope_dim = rope_dim
self.theta = theta
def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
height = height // self.patch_size
width = width // self.patch_size
grid = torch.meshgrid(
frame_indices.to(device=device, dtype=torch.float32),
torch.arange(0, height, device=device, dtype=torch.float32),
torch.arange(0, width, device=device, dtype=torch.float32),
indexing="ij",
) # 3 * [W, H, T]
grid = torch.stack(grid, dim=0) # [3, W, H, T]
freqs = []
for i in range(3):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
class FramepackClipVisionProjection(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.up = nn.Linear(in_channels, out_channels * 3)
self.down = nn.Linear(out_channels * 3, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.up(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.down(hidden_states)
return hidden_states
class HunyuanVideoHistoryPatchEmbed(nn.Module):
def __init__(self, in_channels: int, inner_dim: int):
super().__init__()
self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
def forward(
self,
latents_clean: Optional[torch.Tensor] = None,
latents_clean_2x: Optional[torch.Tensor] = None,
latents_clean_4x: Optional[torch.Tensor] = None,
):
if latents_clean is not None:
latents_clean = self.proj(latents_clean)
latents_clean = latents_clean.flatten(2).transpose(1, 2)
if latents_clean_2x is not None:
latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
latents_clean_2x = self.proj_2x(latents_clean_2x)
latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
if latents_clean_4x is not None:
latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
latents_clean_4x = self.proj_4x(latents_clean_4x)
latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
return latents_clean, latents_clean_2x, latents_clean_4x
class HunyuanVideoFramepackTransformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoHistoryPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: int = 2,
patch_size_t: int = 1,
qk_norm: str = "rms_norm",
guidance_embeds: bool = True,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
has_image_proj: int = False,
image_proj_dim: int = 1152,
has_clean_x_embedder: int = False,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
# 2. RoPE
self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
# Framepack specific modules
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.clean_x_embedder = None
if has_clean_x_embedder:
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.use_gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
image_embeds: torch.Tensor,
indices_latents: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
latents_clean: Optional[torch.Tensor] = None,
indices_latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
indices_latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
if indices_latents is None:
indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
hidden_states = self.x_embedder(hidden_states)
image_rotary_emb = self.rope(
frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
)
latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
latents_clean, latents_history_2x, latents_history_4x
)
if latents_clean is not None and indices_latents_clean is not None:
image_rotary_emb_clean = self.rope(
frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
)
if latents_history_2x is not None and indices_latents_history_2x is not None:
image_rotary_emb_history_2x = self.rope(
frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
)
if latents_history_4x is not None and indices_latents_history_4x is not None:
image_rotary_emb_history_4x = self.rope(
frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
)
hidden_states, image_rotary_emb = self._pack_history_states(
hidden_states,
latents_clean,
latents_history_2x,
latents_history_4x,
image_rotary_emb,
image_rotary_emb_clean,
image_rotary_emb_history_2x,
image_rotary_emb_history_4x,
post_patch_height,
post_patch_width,
)
temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
encoder_hidden_states_image = self.image_projection(image_embeds)
attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
# must cat before (not after) encoder_hidden_states, due to attn masking
encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
if batch_size == 1:
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
attention_mask = None
else:
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
hidden_states = hidden_states[:, -original_context_length:]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
def _pack_history_states(
self,
hidden_states: torch.Tensor,
latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
height: int = None,
width: int = None,
):
image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
if latents_clean is not None and image_rotary_emb_clean is not None:
hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
return hidden_states, tuple(image_rotary_emb)
def _pad_rotary_emb(
self,
image_rotary_emb: Tuple[torch.Tensor],
height: int,
width: int,
kernel_size: Tuple[int, int, int],
):
# freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
freqs_cos, freqs_sin = image_rotary_emb
freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
return freqs_cos, freqs_sin
def _pad_for_3d_conv(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
b, c, t, h, w = x.shape
pt, ph, pw = kernel_size
pad_t = (pt - (t % pt)) % pt
pad_h = (ph - (h % ph)) % ph
pad_w = (pw - (w % pw)) % pw
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def _center_down_sample_3d(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
...@@ -227,6 +227,7 @@ else: ...@@ -227,6 +227,7 @@ else:
"HunyuanVideoPipeline", "HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline", "HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
] ]
_import_structure["kandinsky"] = [ _import_structure["kandinsky"] = [
"KandinskyCombinedPipeline", "KandinskyCombinedPipeline",
...@@ -589,6 +590,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -589,6 +590,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .hidream_image import HiDreamImagePipeline from .hidream_image import HiDreamImagePipeline
from .hunyuan_video import ( from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline, HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline, HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline, HunyuanVideoPipeline,
) )
......
...@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable: ...@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"] _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"] _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -36,6 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -36,6 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline from .pipeline_hunyuan_video import HunyuanVideoPipeline
from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
else: else:
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
import torch import torch
from diffusers.utils import BaseOutput from diffusers.utils import BaseOutput
...@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput): ...@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput):
""" """
frames: torch.Tensor frames: torch.Tensor
@dataclass
class HunyuanVideoFramepackPipelineOutput(BaseOutput):
r"""
Output class for HunyuanVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor
corresponds to a latent that decodes to multiple frames.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
...@@ -565,6 +565,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): ...@@ -565,6 +565,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanVideoTransformer3DModel(metaclass=DummyObject): class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -692,6 +692,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject): ...@@ -692,6 +692,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject): class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
LlamaConfig,
LlamaModel,
LlamaTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoFramepackPipeline,
HunyuanVideoFramepackTransformer3DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)
enable_full_determinism()
class HunyuanVideoFramepackPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanVideoFramepackPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["image", "prompt"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoFramepackTransformer3DModel(
in_channels=4,
out_channels=4,
num_attention_heads=2,
attention_head_dim=10,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=2,
patch_size_t=1,
guidance_embeds=True,
text_embed_dim=16,
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
image_condition_type=None,
has_image_proj=True,
image_proj_dim=32,
has_clean_x_embedder=True,
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
latent_channels=4,
down_block_types=(
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types=(
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
act_fn="silu",
norm_num_groups=4,
scaling_factor=0.476986,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
mid_block_add_attention=True,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
llama_text_encoder_config = LlamaConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = LlamaModel(llama_text_encoder_config)
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = SiglipImageProcessor.from_pretrained(
"hf-internal-testing/tiny-random-SiglipVisionModel", size={"height": 30, "width": 30}
)
image_encoder = SiglipVisionModel.from_pretrained("hf-internal-testing/tiny-random-SiglipVisionModel")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"feature_extractor": feature_extractor,
"image_encoder": image_encoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 32
image_width = 32
image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"prompt_template": {
"template": "{}",
"crop_start": 0,
},
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.5,
"height": image_height,
"width": image_width,
"num_frames": 9,
"latent_window_size": 3,
"max_sequence_length": 256,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (13, 3, 32, 32))
expected_video = torch.randn(13, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
# Seems to require higher tolerance than the other tests
expected_diff_max = 0.6
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment