Unverified Commit a138d71e authored by YiYi Xu, committed by GitHub

HunyuanImage21 (#12333)



* add hunyuanimage2.1


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bc403988
@@ -36,6 +36,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
...@@ -91,6 +93,7 @@ if is_torch_available(): ...@@ -91,6 +93,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
...@@ -133,6 +136,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -133,6 +136,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro, AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLCosmos, AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -182,6 +187,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
...
@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
...
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageResnetBlock(nn.Module):
r"""
Residual block with two convolutions and optional channel change.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x):
# Save the input for the shortcut connection
residual = x
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
# Add residual connection
return x + residual
class HunyuanImageAttentionBlock(nn.Module):
r"""
Self-attention with a single head.
Args:
in_channels (int): The number of channels in the input tensor.
"""
def __init__(self, in_channels: int):
super().__init__()
# layers
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
self.proj = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
identity = x
x = self.norm(x)
# compute query, key, value
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, height, width = query.shape
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
# apply attention
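# NOTE: with query/key/value shaped (batch, height*width, channels) and no explicit head
# dimension, scaled_dot_product_attention performs single-head global attention over all
# spatial positions.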
x = F.scaled_dot_product_attention(query, key, value)
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
# output projection
x = self.proj(x)
return x + identity
class HunyuanImageDownsample(nn.Module):
"""
Downsampling block for spatial reduction.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
if out_channels % factor != 0:
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.group_size = factor * in_channels // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
h = h.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = x.shape
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
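# Illustrative shape check (editor's sketch, not part of the original module): the conv branch
# halves H and W via a 3x3 conv followed by space-to-depth, while the shortcut is a
# parameter-free space-to-depth plus grouped channel mean, so both branches match, e.g.:
#
#     down = HunyuanImageDownsample(in_channels=128, out_channels=256)
#     down(torch.randn(1, 128, 64, 64)).shape  # torch.Size([1, 256, 32, 32])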
class HunyuanImageUpsample(nn.Module):
"""
Upsampling block for spatial expansion.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.repeats = factor * out_channels // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
h = h.reshape(B, C // 4, H * 2, W * 2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
B, C, H, W = shortcut.shape
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
return h + shortcut
class HunyuanImageMidBlock(nn.Module):
"""
Middle block for HunyuanImageVAE encoder and decoder.
Args:
in_channels (int): Number of input channels.
num_layers (int): Number of layers.
"""
def __init__(self, in_channels: int, num_layers: int = 1):
super().__init__()
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
attentions = []
for _ in range(num_layers):
attentions.append(HunyuanImageAttentionBlock(in_channels))
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnets[0](x)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
x = attn(x)
x = resnet(x)
return x
class HunyuanImageEncoder2D(nn.Module):
r"""
Encoder network that compresses input to latent representation.
Args:
in_channels (int): Number of input channels.
z_channels (int): Number of latent channels.
block_out_channels (list of int): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial downsampling factor.
non_linearity (str): Type of non-linearity to use. Default is "silu".
downsample_match_channel (bool): Whether to match channels during downsampling.
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
non_linearity: str = "silu",
downsample_match_channel: bool = True,
):
super().__init__()
if block_out_channels[-1] % (2 * z_channels) != 0:
raise ValueError(
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
)
self.in_channels = in_channels
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.spatial_compression_ratio = spatial_compression_ratio
self.group_size = block_out_channels[-1] // (2 * z_channels)
self.nonlinearity = get_activation(non_linearity)
# init block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
block_in_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
# residual blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# downsample block
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if downsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.down_blocks.append(
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# middle blocks
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
# output blocks
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
## downsamples
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(down_block, x)
else:
x = down_block(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(self.mid_block, x)
else:
x = self.mid_block(x)
## head
B, C, H, W = x.shape
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
x = self.norm_out(x)
x = self.nonlinearity(x)
x = self.conv_out(x)
return x + residual
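# Shape summary (editor's note): the encoder maps (B, in_channels, H, W) to
# (B, 2 * z_channels, H // spatial_compression_ratio, W // spatial_compression_ratio); the
# grouped channel mean of the final feature map is added as a parameter-free residual to the
# projected output.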
class HunyuanImageDecoder2D(nn.Module):
r"""
Decoder network that reconstructs output from latent representation.
Args:
z_channels : int
Number of latent channels.
out_channels : int
Number of output channels.
block_out_channels : Tuple[int, ...]
Output channels for each block.
num_res_blocks : int
Number of residual blocks per block.
spatial_compression_ratio : int
Spatial upsampling factor.
upsample_match_channel : bool
Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
upsample_match_channel: bool = True,
non_linearity: str = "silu",
):
super().__init__()
if block_out_channels[0] % z_channels != 0:
raise ValueError(
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
)
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.repeat = block_out_channels[0] // z_channels
self.spatial_compression_ratio = spatial_compression_ratio
self.nonlinearity = get_activation(non_linearity)
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# Middle blocks with attention
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
# Upsampling blocks
block_in_channel = block_out_channels[0]
self.up_blocks = nn.ModuleList()
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
for _ in range(self.num_res_blocks + 1):
self.up_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if upsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
block_in_channel = block_out_channel
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid_block, h)
else:
h = self.mid_block(h)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(up_block, h)
else:
h = up_block(h)
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# fmt: off
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
spatial_compression_ratio: int,
sample_size: int,
scaling_factor: Optional[float] = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
) -> None:
# fmt: on
super().__init__()
self.encoder = HunyuanImageEncoder2D(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageDecoder2D(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
# Tiling and slicing configuration
self.use_slicing = False
self.use_tiling = False
# Tiling parameters
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_size: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
allow processing larger images.
Args:
tile_sample_min_size (`int`, *optional*):
The minimum size required for a sample to be separated into tiles across the spatial dimension.
tile_overlap_factor (`float`, *optional*):
The fraction by which neighbouring tiles overlap, used both for tile placement and for blending.
"""
self.use_tiling = True
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode input using spatial tiling strategy.
Args:
x (`torch.Tensor`): Input tensor of shape (B, C, H, W).
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
_, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
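# Worked example of the tiling arithmetic (editor's sketch, illustrative numbers only): with
# tile_sample_min_size=512, spatial_compression_ratio=32 and tile_overlap_factor=0.25, tiles are
# taken every overlap_size = 512 * 0.75 = 384 pixels, each tile encodes to a 16-latent-wide patch,
# neighbouring patches are blended over blend_extent = 16 * 0.25 = 4 latents, and each patch is
# cropped to row_limit = 16 - 4 = 12 latents before concatenation.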
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode latent using spatial tiling strategy.
Args:
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
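A minimal usage sketch for the new image VAE (editor's addition, not part of the diff). It builds a deliberately tiny, randomly initialized configuration purely to show the encode/decode API and the tiling switch; the released HunyuanImage-2.1 checkpoint uses a much larger configuration.

import torch

from diffusers.models import AutoencoderKLHunyuanImage

# Illustrative, randomly initialized config (not the released checkpoint).
vae = AutoencoderKLHunyuanImage(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    spatial_compression_ratio=2,
    sample_size=64,
)
vae.enable_tiling()  # optional: encode/decode large inputs tile by tile to save memory

image = torch.randn(1, 3, 64, 64)
latents = vae.encode(image).latent_dist.sample()  # (1, 4, 32, 32)
reconstruction = vae.decode(latents).sample       # (1, 3, 64, 64)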
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageRefinerCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
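# NOTE: F.pad consumes this tuple as (W_left, W_right, H_top, H_bottom, T_front, T_back), so
# width and height are padded symmetrically while the time axis is padded only at the front
# with (kernel_size - 1) frames, making the 3D convolution causal in time.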
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanImageRefinerRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
h = h[:, : h.shape[1] // 2]
# shortcut computation
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
h = torch.cat([h, h], dim=1)
# shortcut computation
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
else:
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
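# Illustrative shape check (editor's sketch, not part of the original module): for a
# single-frame input the DCAE downsample halves the spatial dimensions and moves from
# in_channels to out_channels, with a parameter-free space-to-depth + grouped-mean shortcut:
#
#     down = HunyuanImageRefinerDownsampleDCAE(128, 256, add_temporal_downsample=True)
#     down(torch.randn(1, 128, 1, 64, 64)).shape  # torch.Size([1, 256, 1, 32, 32])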
class HunyuanImageRefinerResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanImageRefinerMidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanImageRefinerDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanImageRefinerDownsampleDCAE(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanImageRefinerUpsampleDCAE(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerEncoder3D(nn.Module):
r"""
3D vae encoder for HunyuanImageRefiner.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanImageRefinerDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanImageRefinerEncoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageRefinerDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[int] = None,
tile_sample_stride_width: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 16 * 0.25 = 4
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 16 * 0.25 = 4
row_limit_height = tile_latent_min_height - blend_height # 16 - 4 = 12
row_limit_width = tile_latent_min_width - blend_width # 16 - 4 = 12
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
Returns:
`torch.Tensor`:
The decoded videos.
"""
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 16 * (1 - 0.25) = 12
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 16 * (1 - 0.25) = 12
blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192
row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
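A minimal usage sketch for the refiner VAE (editor's addition, not part of the diff). Because it is built from causal 3D convolutions, an image is passed in as a single-frame clip of shape (B, C, T, H, W); the configuration below is deliberately tiny and randomly initialized rather than the released checkpoint.

import torch

from diffusers.models import AutoencoderKLHunyuanImageRefiner

# Illustrative, randomly initialized config (not the released checkpoint).
vae = AutoencoderKLHunyuanImageRefiner(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    spatial_compression_ratio=2,
    temporal_compression_ratio=1,
)

image = torch.randn(1, 3, 256, 256).unsqueeze(2)  # (1, 3, 1, 256, 256): add a singleton frame axis
latents = vae.encode(image).latent_dist.sample()  # (1, 4, 1, 128, 128)
reconstruction = vae.decode(latents).sample       # (1, 3, 1, 256, 256)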
@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_lumina2 import Lumina2Transformer2DModel
......
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat(
[
apply_rotary_emb(
query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
),
query[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
key = torch.cat(
[
apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
key[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
else:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
# 5. Attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
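# Descriptive summary of the processor above: when `attn.add_q_proj` is None (single-stream blocks) the
# text tokens are appended to the image tokens before a shared QKV projection, whereas dual-stream blocks
# project the text stream with the separate `add_{q,k,v}_proj` weights and concatenate after QK-norm;
# in both cases RoPE is applied only to the image-token positions, and the fused sequence is split back
# into (hidden_states, encoder_hidden_states) before the output projections.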
class HunyuanImagePatchEmbed(nn.Module):
def __init__(
self,
patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
self.patch_size = patch_size
if len(patch_size) == 2:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
elif len(patch_size) == 3:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
else:
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
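# Shape walkthrough (descriptive): for a 2D patch size (p, p), (B, C, H, W) -> Conv2d(stride=p) ->
# (B, embed_dim, H // p, W // p) -> flatten(2).transpose(1, 2) -> (B, (H // p) * (W // p), embed_dim).
# With the HunyuanImage default patch_size=(1, 1), every latent position becomes one token.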
class HunyuanImageByT5TextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, out_features: int):
super().__init__()
self.norm = nn.LayerNorm(in_features)
self.linear_1 = nn.Linear(in_features, hidden_size)
self.linear_2 = nn.Linear(hidden_size, hidden_size)
self.linear_3 = nn.Linear(hidden_size, out_features)
self.act_fn = nn.GELU()
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(encoder_hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_3(hidden_states)
return hidden_states
class HunyuanImageAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
self.linear = nn.Linear(in_features, out_features)
self.nonlinearity = nn.SiLU()
def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
return gate_msa, gate_mlp
class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
guidance_embeds: bool = False,
use_meanflow: bool = False,
):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_meanflow = use_meanflow
self.time_proj_r = None
self.timestep_embedder_r = None
if use_meanflow:
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self,
timestep: torch.Tensor,
timestep_r: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
if timestep_r is not None:
timesteps_proj_r = self.time_proj_r(timestep_r)
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
conditioning = timesteps_emb + guidance_emb
else:
conditioning = timesteps_emb
return conditioning
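# Descriptive note: `timestep_r` is only meaningful when the module was built with `use_meanflow=True`
# (otherwise `time_proj_r` / `timestep_embedder_r` are None); when provided, its embedding is averaged
# with the main timestep embedding. When `guidance_embeds=True`, the guidance value is embedded with the
# shared sinusoidal projection `time_proj` and added to the timestep embedding.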
# IndividualTokenRefinerBlock
@maybe_allow_in_graph
class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int, # 28
attention_head_dim: int, # 128
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
)
gate_msa, gate_mlp = self.norm_out(temb)
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
class HunyuanImageIndividualTokenRefiner(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
self.refiner_blocks = nn.ModuleList(
[
HunyuanImageIndividualTokenRefinerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
self_attn_mask = None
if attention_mask is not None:
batch_size = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
attention_mask = attention_mask.to(hidden_states.device)
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
for block in self.refiner_blocks:
hidden_states = block(hidden_states, temb, self_attn_mask)
return hidden_states
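# Mask construction example (descriptive): for a per-token mask [1, 1, 0], mask_i & mask_j gives
#     [[1, 1, 0],
#      [1, 1, 0],
#      [0, 0, 0]]
# and the first key column is then forced to True so that fully padded query rows still attend to at
# least one token, keeping the attention softmax well-defined.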
# txt_in
class HunyuanImageTokenRefiner(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=hidden_size, pooled_projection_dim=in_channels
)
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
self.token_refiner = HunyuanImageIndividualTokenRefiner(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_layers,
mlp_width_ratio=mlp_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_hidden_states = hidden_states.mean(dim=1)
else:
original_dtype = hidden_states.dtype
mask_float = attention_mask.float().unsqueeze(-1)
pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
pooled_hidden_states = pooled_hidden_states.to(original_dtype)
temb = self.time_text_embed(timestep, pooled_hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
return hidden_states
class HunyuanImageRotaryPosEmbed(nn.Module):
def __init__(
self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
) -> None:
super().__init__()
if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
if not len(patch_size) == len(rope_dim):
raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
self.patch_size = patch_size
self.rope_dim = rope_dim
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if hidden_states.ndim == 5:
_, _, frame, height, width = hidden_states.shape
patch_size_frame, patch_size_height, patch_size_width = self.patch_size
rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
elif hidden_states.ndim == 4:
_, _, height, width = hidden_states.shape
patch_size_height, patch_size_width = self.patch_size
rope_sizes = [height // patch_size_height, width // patch_size_width]
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
axes_grids = []
for i in range(len(rope_sizes)):
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
axes_grids.append(grid)
grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
grid = torch.stack(grid, dim=0) # [2, H, W]
freqs = []
for i in range(len(rope_sizes)):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (num_positions, sum(rope_dim))
freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (num_positions, sum(rope_dim))
return freqs_cos, freqs_sin
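# Example (descriptive, default image config): with patch_size=(1, 1) and rope_axes_dim=(64, 64), a
# latent of shape (B, C, 128, 128) yields a 128 x 128 position grid; each axis gets 1D rotary
# frequencies of width rope_dim[i], and the per-axis cos/sin tables are concatenated along the last
# dimension, so rope_axes_dim is chosen to sum to attention_head_dim (64 + 64 = 128 here).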
@maybe_allow_in_graph
class HunyuanImageSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
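# Descriptive note: this single-stream block uses a parallel attention + MLP layout: one
# AdaLayerNormZeroSingle modulation normalizes the fused image+text sequence, attention and the MLP
# branch are computed from that same normalized sequence, and their concatenation is mapped back by
# `proj_out`, scaled by a single gate, and added to the residual.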
@maybe_allow_in_graph
class HunyuanImageTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
Args:
in_channels (`int`, defaults to `64`):
    The number of channels in the input.
out_channels (`int`, defaults to `64`):
    The number of channels in the output.
num_attention_heads (`int`, defaults to `28`):
    The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
    The number of channels in each head.
num_layers (`int`, defaults to `20`):
    The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
    The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
    The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
    The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`Tuple[int, int]`, defaults to `(1, 1)`):
    The size of the spatial patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
    The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `False`):
    Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `3584`):
    Input dimension of text embeddings from the text encoder.
text_embed_2_dim (`int`, *optional*):
    Input dimension of the secondary (byT5 glyph) text embeddings. If `None`, the byT5 projection branch
    is disabled.
rope_theta (`float`, defaults to `256.0`):
    The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(64, 64)`):
    The dimensions of the axes to use in the RoPE layer.
use_meanflow (`bool`, defaults to `False`):
    Whether to embed an additional timestep `timestep_r` and average its embedding with the main timestep
    embedding.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
"HunyuanImagePatchEmbed",
"HunyuanImageTokenRefiner",
]
_repeated_blocks = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
]
@register_to_config
def __init__(
self,
in_channels: int = 64,
out_channels: int = 64,
num_attention_heads: int = 28,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: Tuple[int, int] = (1, 1),
qk_norm: str = "rms_norm",
guidance_embeds: bool = False,
text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64),
use_meanflow: bool = False,
) -> None:
super().__init__()
if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
self.context_embedder = HunyuanImageTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if text_embed_2_dim is not None:
self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
else:
self.context_embedder_2 = None
self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
# 2. RoPE
self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanImageTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanImageSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
timestep_r: Optional[torch.LongTensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
encoder_attention_mask_2: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
elif hidden_states.ndim == 5:
batch_size, channels, frame, height, width = hidden_states.shape
sizes = (frame, height, width)
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
encoder_attention_mask = encoder_attention_mask.bool()
temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
# reorder and combine text tokens: combine valid tokens first, then padding
new_encoder_hidden_states = []
new_encoder_attention_mask = []
for text, text_mask, text_2, text_mask_2 in zip(
encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
):
# Concatenate: [valid_byt5, valid_mllm, invalid_byt5, invalid_mllm]
new_encoder_hidden_states.append(
torch.cat(
[
text_2[text_mask_2], # valid byt5
text[text_mask], # valid mllm
text_2[~text_mask_2], # invalid byt5
text[~text_mask], # invalid mllm
],
dim=0,
)
)
# Apply same reordering to attention masks
new_encoder_attention_mask.append(
torch.cat(
[
text_mask_2[text_mask_2],
text_mask[text_mask],
text_mask_2[~text_mask_2],
text_mask[~text_mask],
],
dim=0,
)
)
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# 3. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. unpatchify
# reshape: [batch_size, *post_patch_dims, channels, *patch_size]
out_channels = self.config.out_channels
reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
hidden_states = hidden_states.reshape(*reshape_dims)
# create permutation pattern: batch, channels, then interleave post_patch and patch dims
# For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
# For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
ndim = len(post_patch_sizes)
permute_pattern = [0, ndim + 1] # batch, channels
for i in range(ndim):
permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
hidden_states = hidden_states.permute(*permute_pattern)
# flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
# batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
final_dims = [batch_size, out_channels] + [
post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
]
hidden_states = hidden_states.reshape(*final_dims)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
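# Minimal usage sketch (illustrative only; the shapes and values below are assumptions based on the
# default config, not taken from the original PR):
#
#     model = HunyuanImageTransformer2DModel()          # defaults: in_channels=64, patch_size=(1, 1)
#     latents = torch.randn(1, 64, 64, 64)              # (B, C, H, W) packed image latents
#     timestep = torch.tensor([500.0])
#     prompt_embeds = torch.randn(1, 1000, 3584)        # text-encoder hidden states (text_embed_dim=3584)
#     prompt_mask = torch.ones(1, 1000, dtype=torch.long)
#     out = model(latents, timestep, prompt_embeds, prompt_mask).sample   # -> (1, 64, 64, 64)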
...@@ -130,8 +130,14 @@ class PipelineState:
    Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
    intermediates dict.
    """
    # Use object.__getattribute__ to avoid infinite recursion during deepcopy
    try:
        values = object.__getattribute__(self, "values")
    except AttributeError:
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
    if name in values:
        return values[name]
    raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
...@@ -2492,6 +2498,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
    """
    if state is None:
        state = PipelineState()
    else:
        state = deepcopy(state)
    # Make a copy of the input kwargs
    passed_kwargs = kwargs.copy()
...
...@@ -238,19 +238,27 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
    guider_inputs = {
        "encoder_hidden_states": (
            getattr(block_state, "prompt_embeds", None),
            getattr(block_state, "negative_prompt_embeds", None),
        ),
        "encoder_hidden_states_mask": (
            getattr(block_state, "prompt_embeds_mask", None),
            getattr(block_state, "negative_prompt_embeds_mask", None),
        ),
        "txt_seq_lens": (
            getattr(block_state, "txt_seq_lens", None),
            getattr(block_state, "negative_txt_seq_lens", None),
        ),
    }
    components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
    guider_state = components.guider.prepare_inputs(guider_inputs)
    for guider_state_batch in guider_state:
        components.guider.prepare_models(components.transformer)
        cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
        # YiYi TODO: add cache context
        guider_state_batch.noise_pred = components.transformer(
...@@ -328,19 +336,27 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
    guider_inputs = {
        "encoder_hidden_states": (
            getattr(block_state, "prompt_embeds", None),
            getattr(block_state, "negative_prompt_embeds", None),
        ),
        "encoder_hidden_states_mask": (
            getattr(block_state, "prompt_embeds_mask", None),
            getattr(block_state, "negative_prompt_embeds_mask", None),
        ),
        "txt_seq_lens": (
            getattr(block_state, "txt_seq_lens", None),
            getattr(block_state, "negative_txt_seq_lens", None),
        ),
    }
    components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
    guider_state = components.guider.prepare_inputs(guider_inputs)
    for guider_state_batch in guider_state:
        components.guider.prepare_models(components.transformer)
        cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
        # YiYi TODO: add cache context
        guider_state_batch.noise_pred = components.transformer(
...
...@@ -201,27 +201,41 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
    # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
    # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
    guider_inputs = {
        "prompt_embeds": (
            getattr(block_state, "prompt_embeds", None),
            getattr(block_state, "negative_prompt_embeds", None),
        ),
        "time_ids": (
            getattr(block_state, "add_time_ids", None),
            getattr(block_state, "negative_add_time_ids", None),
        ),
        "text_embeds": (
            getattr(block_state, "pooled_prompt_embeds", None),
            getattr(block_state, "negative_pooled_prompt_embeds", None),
        ),
        "image_embeds": (
            getattr(block_state, "ip_adapter_embeds", None),
            getattr(block_state, "negative_ip_adapter_embeds", None),
        ),
    }
    components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
    # The guider splits model inputs into separate batches for conditional/unconditional predictions.
    # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
    # you will get a guider_state with two batches:
    # guider_state = [
    #     {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
    #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
    # ]
    # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
    guider_state = components.guider.prepare_inputs(guider_inputs)
    # run the denoiser for each guidance batch
    for guider_state_batch in guider_state:
        components.guider.prepare_models(components.unet)
        cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
        prompt_embeds = cond_kwargs.pop("prompt_embeds")
        # Predict the noise residual
...@@ -344,11 +358,23 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
    # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
    # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
    guider_inputs = {
        "prompt_embeds": (
            getattr(block_state, "prompt_embeds", None),
            getattr(block_state, "negative_prompt_embeds", None),
        ),
        "time_ids": (
            getattr(block_state, "add_time_ids", None),
            getattr(block_state, "negative_add_time_ids", None),
        ),
        "text_embeds": (
            getattr(block_state, "pooled_prompt_embeds", None),
            getattr(block_state, "negative_pooled_prompt_embeds", None),
        ),
        "image_embeds": (
            getattr(block_state, "ip_adapter_embeds", None),
            getattr(block_state, "negative_ip_adapter_embeds", None),
        ),
    }
    # cond_scale for the timestep (controlnet input)
...@@ -369,12 +395,15 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
    # guided denoiser step
    components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
    # The guider splits model inputs into separate batches for conditional/unconditional predictions.
    # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
    # you will get a guider_state with two batches:
    # guider_state = [
    #     {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
    #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
    # ]
    # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
    guider_state = components.guider.prepare_inputs(guider_inputs)
    # run the denoiser for each guidance batch
    for guider_state_batch in guider_state:
...
...@@ -94,25 +94,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
    # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
    # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
    guider_inputs = {
        "prompt_embeds": (
            getattr(block_state, "prompt_embeds", None),
            getattr(block_state, "negative_prompt_embeds", None),
        ),
    }
    transformer_dtype = components.transformer.dtype
    components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
    # The guider splits model inputs into separate batches for conditional/unconditional predictions.
    # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
    # you will get a guider_state with two batches:
    # guider_state = [
    #     {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
    #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
    # ]
    # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
    guider_state = components.guider.prepare_inputs(guider_inputs)
    # run the denoiser for each guidance batch
    for guider_state_batch in guider_state:
        components.guider.prepare_models(components.transformer)
        cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
        prompt_embeds = cond_kwargs.pop("prompt_embeds")
        # Predict the noise residual
...
...@@ -241,6 +241,7 @@ else:
    "HunyuanVideoImageToVideoPipeline",
    "HunyuanVideoFramepackPipeline",
]
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
    "KandinskyCombinedPipeline",
    "KandinskyImg2ImgCombinedPipeline",
...@@ -640,6 +641,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
    HunyuanSkyreelsImageToVideoPipeline,
    HunyuanVideoFramepackPipeline,
...
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
_import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuanimage import HunyuanImagePipeline
from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImagePipeline
>>> pipe = HunyuanImagePipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
>>> image.save("hunyuanimage.png")
```
"""
def extract_glyph_text(prompt: str):
"""
Extract text enclosed in quotes for glyph rendering.
Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
Args:
prompt: Input text prompt
Returns:
Formatted glyph text string or None if no quoted text found
"""
text_prompt_texts = []
pattern_quote_single = r"\'(.*?)\'"
pattern_quote_double = r"\"(.*?)\""
pattern_quote_chinese_single = r"‘(.*?)’"
pattern_quote_chinese_double = r"“(.*?)”"
matches_quote_single = re.findall(pattern_quote_single, prompt)
matches_quote_double = re.findall(pattern_quote_double, prompt)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if text_prompt_texts:
glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
else:
glyph_text_formatted = None
return glyph_text_formatted
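# Example (descriptive): extract_glyph_text('A street sign that says "hello world"') returns
# 'Text "hello world". ', while a prompt without any quoted text returns None; the formatted string is
# what gets fed to the byT5 branch for glyph conditioning.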
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanImagePipeline(DiffusionPipeline):
r"""
The HunyuanImage pipeline for text-to-image generation.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImage`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
    The primary text encoder,
    [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
text_encoder_2 ([`T5EncoderModel`]):
    A byT5-based
    [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
    used to encode glyph text for text rendering.
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer].
guider ([`AdaptiveProjectedMixGuidance`], *optional*):
    [`AdaptiveProjectedMixGuidance`] to be used to guide the image generation.
ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
    [`AdaptiveProjectedMixGuidance`] to be used to guide the image generation when text rendering is needed.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["ocr_guider", "guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImage,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: ByT5Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
guider=guider,
ocr_guider=ocr_guider,
)
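# spatial compression ratio of the HunyuanImage VAE; falls back to 32 when the pipeline is loaded without a VAE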
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 1000
self.tokenizer_2_max_length = 128
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
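# number of tokens taken up by the template prefix above; these tokens are dropped from the text encoder
# output in `_get_qwen_prompt_embeds`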
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 64
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
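# use an intermediate hidden state (`hidden_state_skip_layer` layers before the last) and drop the template prefix tokens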
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def _get_byt5_prompt_embeds(
self,
tokenizer: ByT5Tokenizer,
text_encoder: T5EncoderModel,
prompt: str,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 128,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
if isinstance(prompt, list):
raise ValueError("byt5 prompt should be a string")
elif prompt is None:
raise ValueError("byt5 prompt should not be None")
txt_tokens = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to(device)
prompt_embeds = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask.float(),
)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
if prompt_embeds_2 is None:
prompt_embeds_2_list = []
prompt_embeds_mask_2_list = []
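# extract the glyph text (text intended to be rendered in the image) for ByT5 conditioning;
# prompts without glyph text get all-zero embeddings and masks below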
glyph_texts = [extract_glyph_text(p) for p in prompt]
for glyph_text in glyph_texts:
if glyph_text is None:
glyph_text_embeds = torch.zeros(
(1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
)
glyph_text_embeds_mask = torch.zeros(
(1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
)
else:
glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
tokenizer=self.tokenizer_2,
text_encoder=self.text_encoder_2,
prompt=glyph_text,
device=device,
tokenizer_max_length=self.tokenizer_2_max_length,
)
prompt_embeds_2_list.append(glyph_text_embeds)
prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
_, seq_len_2, _ = prompt_embeds_2.shape
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
prompt_embeds_2=None,
prompt_embeds_mask_2=None,
negative_prompt_embeds_2=None,
negative_prompt_embeds_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
raise ValueError(
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
)
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
raise ValueError(
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
distilled_guidance_scale: Optional[float] = 3.25,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined and `negative_prompt_embeds` is
not provided, an empty negative prompt is used. Ignored when guidance is not applied.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This defaults to 2048 for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This defaults to 2048 for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
distilled_guidance_scale (`float`, *optional*, defaults to 3.25):
A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance, where
the guidance scale is applied during inference through noise prediction rescaling, guidance-distilled
models take the guidance scale directly as an input to the forward pass. Guidance is enabled by setting
`distilled_guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely
linked to the text `prompt`, usually at the expense of lower image quality. For guidance-distilled models
this parameter is required; for non-distilled models it is ignored.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, text embeddings mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
`negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
negative_prompt_embeds_2=negative_prompt_embeds_2,
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. prepare prompt embeds
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
# select guider
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
# prompt contains ocr and pipeline has a guider for ocr
guider = self.ocr_guider
elif self.guider is not None:
guider = self.guider
else:
# distilled model does not use guidance method, use default guider with enabled=False
guider = AdaptiveProjectedMixGuidance(enabled=False)
if guider._enabled and guider.num_conditions > 1:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
negative_prompt_embeds_2,
negative_prompt_embeds_mask_2,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=negative_prompt_embeds_2,
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (for guidance-distilled model)
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor(
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
)
* 1000.0
)
else:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
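# MeanFlow variants additionally condition on the next timestep in the schedule (`timestep_r`), set to 0 at the final step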
if self.transformer.config.use_meanflow:
if i == len(timesteps) - 1:
timestep_r = torch.tensor([0.0], device=device)
else:
timestep_r = timesteps[i + 1]
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
else:
timestep_r = None
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep,
timestep_r=timestep_r,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImageRefinerPipeline
>>> from diffusers.utils import load_image
>>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = load_image("path/to/image.png")
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
>>> image.save("hunyuanimage.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class HunyuanImageRefinerPipeline(DiffusionPipeline):
r"""
The HunyuanImage refiner pipeline, which refines a generated image conditioned on a text prompt.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImageRefiner`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
The text encoder, [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [`Qwen2Tokenizer`].
guider ([`AdaptiveProjectedMixGuidance`], *optional*):
[`AdaptiveProjectedMixGuidance`] used to guide the image generation.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImageRefiner,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
guider=guider,
)
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 256
self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
self.prompt_template_encode_start_idx = 36
self.default_sample_size = 64
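# the refiner transformer receives the noisy latents concatenated channel-wise with image-conditioning
# latents, so its `in_channels` is twice the number of VAE latent channels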
self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
def prepare_latents(
self,
image_latents,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
strength=0.25,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, 1, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
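# blend the encoded image with noise; `strength` controls how strongly the conditioning latents are noised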
cond_latents = strength * noise + (1 - strength) * image_latents
return latents, cond_latents
@staticmethod
def _reorder_image_tokens(image_latents):
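"""Duplicate the first frame slice and fold pairs of latent frames into the channel dimension: (B, C, F, H, W) -> (B, 2C, (F + 1) // 2, H, W)."""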
image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
image_latents = image_latents.permute(0, 2, 1, 3, 4)
image_latents = image_latents.reshape(
batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
)
image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
return image_latents
@staticmethod
def _restore_image_tokens_order(latents):
"""Restore image tokens order by splitting channels and removing first frame slice."""
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
latents = latents.reshape(
batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
) # B, F*2, C//2, H, W
latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
# Remove first frame slice
latents = latents[:, :, 1:]
return latents
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
image_latents = self._reorder_image_tokens(image_latents)
image_latents = image_latents * self.vae.config.scaling_factor
return image_latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
distilled_guidance_scale: Optional[float] = 3.25,
image: Optional[PipelineImageInput] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, an empty negative prompt is
used. Ignored when guidance is not applied.
distilled_guidance_scale (`float`, *optional*, defaults to 3.25):
A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance, where
the guidance scale is applied during inference through noise prediction rescaling, guidance-distilled
models take the guidance scale directly as an input to the forward pass. Guidance is enabled by setting
`distilled_guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely
linked to the text `prompt`, usually at the expense of lower image quality. This pipeline only supports
guidance-distilled models, so this parameter is required.
image (`PipelineImageInput`, *optional*):
The image to refine. Can be passed as pixel values (e.g. a `PIL.Image.Image`) or as pre-encoded latents
with `self.latent_channels` channels.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This defaults to 1024 for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This defaults to 1024 for the best results.
num_inference_steps (`int`, *optional*, defaults to 4):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. process image
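# `image` can be passed either as pixel values (preprocessed and VAE-encoded here) or as pre-encoded latents
# with `self.latent_channels` channels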
if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
image_latents = image
else:
image = self.image_processor.preprocess(image, height, width)
image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
# 4. Prepare prompt embeds
if self.guider is not None:
guider = self.guider
else:
# distilled model does not use guidance method, use default guider with enabled=False
guider = AdaptiveProjectedMixGuidance(enabled=False)
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
if requires_unconditional_embeds:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
# 5. Prepare latent variables
latents, cond_latents = self.prepare_latents(
image_latents=image_latents,
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=self.latent_channels,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 6. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (this pipeline only supports guidance-distilled models)
if distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
guidance = (
torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
* 1000.0
)
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 7. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# update the progress bar (on scheduler-order boundaries after warmup, and on the final step)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
latents = self._restore_image_tokens_order(latents)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)
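# ---------------------------------------------------------------------------------
# Standalone sketch (added for illustration, not part of the diff): the denoising
# loop above delegates batching and combining to the guider via
# set_state -> prepare_inputs -> one denoiser call per batch -> guider(guider_state).
# The toy code below re-implements only the classifier-free-guidance case with plain
# tensors and a stand-in denoiser so the data flow can be run in isolation; it is a
# simplified sketch, not the diffusers BaseGuidance implementation.
# ---------------------------------------------------------------------------------
import torch


def toy_denoiser(hidden_states, encoder_hidden_states):
    # Stand-in for the transformer: nudges the latents with the text conditioning.
    return hidden_states + 0.1 * encoder_hidden_states.mean(dim=1, keepdim=True)


def toy_cfg_step(latents, prompt_embeds, negative_prompt_embeds, guidance_scale=4.0):
    # "prepare_inputs": one entry per conditioning, tagged like __guidance_identifier__.
    guider_state = [
        {"encoder_hidden_states": prompt_embeds, "id": "pred_cond"},
        {"encoder_hidden_states": negative_prompt_embeds, "id": "pred_uncond"},
    ]
    # Run the denoiser once per batch and store the prediction next to its inputs.
    for batch in guider_state:
        batch["noise_pred"] = toy_denoiser(latents, batch["encoder_hidden_states"])
    # "Combine": the CFG formula described in Step 5 above.
    preds = {batch["id"]: batch["noise_pred"] for batch in guider_state}
    return preds["pred_uncond"] + guidance_scale * (preds["pred_cond"] - preds["pred_uncond"])


toy_latents = torch.randn(1, 16, 32)   # (batch, tokens, channels); shapes are arbitrary
toy_cond = torch.randn(1, 8, 32)       # toy prompt embeddings
toy_uncond = torch.zeros(1, 8, 32)     # toy negative-prompt embeddings
print(toy_cfg_step(toy_latents, toy_cond, toy_uncond).shape)  # torch.Size([1, 16, 32])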
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class HunyuanImagePipelineOutput(BaseOutput):
"""
Output class for HunyuanImage pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. The PIL images or NumPy array represent the denoised images produced by the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
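# Usage sketch (added for illustration, not part of the diff): like other diffusers
# outputs, this class derives from BaseOutput, so the generated images can be read as
# an attribute, by key, or by index.
import numpy as np
import PIL.Image

_dummy = PIL.Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))
_out = HunyuanImagePipelineOutput(images=[_dummy])
assert _out.images[0] is _dummy      # attribute access
assert _out["images"][0] is _dummy   # dict-style access provided by BaseOutput
assert _out[0][0] is _dummy          # tuple-style access provided by BaseOutput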
...@@ -76,6 +76,7 @@ LOADABLE_CLASSES = {
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
"BaseGuidance": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
...
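# Illustration (added, not part of the diff): registering "BaseGuidance" here means a
# guider attached to a pipeline is saved and reloaded with the other components via the
# usual config machinery. Assuming the public ClassifierFreeGuidance guider and its
# `guidance_scale` argument, a stand-alone round trip would look roughly like this:
from diffusers import ClassifierFreeGuidance

guider = ClassifierFreeGuidance(guidance_scale=5.0)
guider.save_pretrained("./cfg_guider")                           # writes the guider config
guider = ClassifierFreeGuidance.from_pretrained("./cfg_guider")  # restores it from disk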
...@@ -17,6 +17,21 @@ class AdaptiveProjectedGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoGuidance(metaclass=DummyObject):
_backends = ["torch"]
...@@ -32,6 +47,21 @@ class AutoGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BaseGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeGuidance(metaclass=DummyObject):
_backends = ["torch"]
...@@ -378,6 +408,36 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
_backends = ["torch"]
...@@ -858,6 +918,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HunyuanImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -1037,6 +1037,36 @@ class HunyuanDiTPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanImageRefinerPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import (
ByT5Tokenizer,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2Tokenizer,
T5Config,
T5EncoderModel,
)
from diffusers import (
AdaptiveProjectedMixGuidance,
AutoencoderKLHunyuanImage,
FlowMatchEulerDiscreteScheduler,
HunyuanImagePipeline,
HunyuanImageTransformer2DModel,
)
from ...testing_utils import enable_full_determinism
from ..test_pipelines_common import (
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
to_np,
)
enable_full_determinism()
class HunyuanImagePipelineFastTests(
PipelineTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanImagePipeline
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt", "negative_prompt"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
test_attention_slicing = False
supports_dduf = False
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1, guidance_embeds: bool = False):
torch.manual_seed(0)
transformer = HunyuanImageTransformer2DModel(
in_channels=4,
out_channels=4,
num_attention_heads=4,
attention_head_dim=8,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=(1, 1),
guidance_embeds=guidance_embeds,
text_embed_dim=32,
text_embed_2_dim=32,
rope_axes_dim=(4, 4),
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanImage(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(32, 64, 64, 64),
layers_per_block=1,
scaling_factor=0.476986,
spatial_compression_ratio=8,
sample_size=128,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
if not guidance_embeds:
torch.manual_seed(0)
guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=2)
ocr_guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=3)
else:
guider = None
ocr_guider = None
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 32,
"intermediate_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [2, 2, 4],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 32,
"intermediate_size": 32,
"num_heads": 2,
"out_hidden_size": 32,
},
hidden_size=32,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
torch.manual_seed(0)
t5_config = T5Config(
d_model=32,
d_kv=4,
d_ff=16,
num_layers=2,
num_heads=2,
relative_attention_num_buckets=8,
relative_attention_max_distance=32,
vocab_size=256,
feed_forward_proj="gated-gelu",
dense_act_fn="gelu_new",
is_encoder_decoder=False,
use_cache=False,
tie_word_embeddings=False,
)
text_encoder_2 = T5EncoderModel(t5_config)
tokenizer_2 = ByT5Tokenizer()
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"guider": guider,
"ocr_guider": ocr_guider,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 5,
"height": 16,
"width": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.6252659, 0.51482046, 0.60799813, 0.59267783, 0.488082, 0.5857634, 0.523781, 0.58028054, 0.5674121]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_inference_guider(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipe.guider = pipe.guider.new(guidance_scale=1000)
pipe.ocr_guider = pipe.ocr_guider.new(guidance_scale=1000)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.61494756, 0.49616697, 0.60327923, 0.6115793, 0.49047345, 0.56977504, 0.53066164, 0.58880305, 0.5570612]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_inference_with_distilled_guidance(self):
device = "cpu"
components = self.get_dummy_components(guidance_embeds=True)
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["distilled_guidance_scale"] = 3.5
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 16, 16))
expected_slice_np = np.array(
[0.63667065, 0.5187377, 0.66757566, 0.6320319, 0.4913387, 0.54813194, 0.5335031, 0.5736143, 0.5461346]
)
output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
self.assertTrue(
np.abs(output_slice - expected_slice_np).max() < 1e-3,
f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(tile_sample_min_size=96)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
def test_encode_prompt_works_in_isolation(self):
pass
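# Quick local-run sketch (added, not part of the test suite): the dummy components above
# build a tiny CPU-only pipeline, so the fast path can also be exercised directly for
# debugging. get_dummy_components() fetches a tiny tokenizer from the Hub, so this needs
# network access on first run.
if __name__ == "__main__":
    tester = HunyuanImagePipelineFastTests("test_inference")
    components = tester.get_dummy_components()
    pipe = HunyuanImagePipeline(**components).to("cpu")
    pipe.set_progress_bar_config(disable=None)
    images = pipe(**tester.get_dummy_inputs("cpu")).images
    print(images[0].shape)  # torch.Size([3, 16, 16]) when output_type="pt"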