Unverified Commit fa2abfdb authored by XCL's avatar XCL Committed by GitHub
Browse files

[Tencent Hunyuan Team] Add Hunyuan-DiT ControlNet Inference (#8694)



* add controlnet support

---------
Co-authored-by: default avatarxingchaoliu <xingchaoliu@tencent.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 1d3ef67b
...@@ -257,6 +257,8 @@ ...@@ -257,6 +257,8 @@
title: PriorTransformer title: PriorTransformer
- local: api/models/controlnet - local: api/models/controlnet
title: ControlNetModel title: ControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3 - local: api/models/controlnet_sd3
title: SD3ControlNetModel title: SD3ControlNetModel
title: Models title: Models
...@@ -282,6 +284,8 @@ ...@@ -282,6 +284,8 @@
title: Consistency Models title: Consistency Models
- local: api/pipelines/controlnet - local: api/pipelines/controlnet
title: ControlNet title: ControlNet
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3 - local: api/pipelines/controlnet_sd3
title: ControlNet with Stable Diffusion 3 title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl - local: api/pipelines/controlnet_sdxl
......
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HunyuanDiT2DControlNetModel
HunyuanDiT2DControlNetModel is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
## Example For Loading HunyuanDiT2DControlNetModel
```py
from diffusers import HunyuanDiT2DControlNetModel
import torch
controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16)
```
## HunyuanDiT2DControlNetModel
[[autodoc]] HunyuanDiT2DControlNetModel
\ No newline at end of file
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet with Hunyuan-DiT
HunyuanDiTControlNetPipeline is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## HunyuanDiTControlNetPipeline
[[autodoc]] HunyuanDiTControlNetPipeline
- all
- __call__
<!--Copyright 2024 The HuggingFace Team. All rights reserved. <!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at the License. You may obtain a copy of the License at
......
...@@ -83,7 +83,9 @@ else: ...@@ -83,7 +83,9 @@ else:
"ControlNetModel", "ControlNetModel",
"ControlNetXSAdapter", "ControlNetXSAdapter",
"DiTTransformer2DModel", "DiTTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel", "HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"I2VGenXLUNet", "I2VGenXLUNet",
"Kandinsky3UNet", "Kandinsky3UNet",
"ModelMixin", "ModelMixin",
...@@ -234,6 +236,7 @@ else: ...@@ -234,6 +236,7 @@ else:
"BlipDiffusionPipeline", "BlipDiffusionPipeline",
"CLIPImageProjection", "CLIPImageProjection",
"CycleDiffusionPipeline", "CycleDiffusionPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPipeline", "HunyuanDiTPipeline",
"I2VGenXLPipeline", "I2VGenXLPipeline",
"IFImg2ImgPipeline", "IFImg2ImgPipeline",
...@@ -500,7 +503,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -500,7 +503,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel, ControlNetModel,
ControlNetXSAdapter, ControlNetXSAdapter,
DiTTransformer2DModel, DiTTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
I2VGenXLUNet, I2VGenXLUNet,
Kandinsky3UNet, Kandinsky3UNet,
ModelMixin, ModelMixin,
...@@ -629,6 +634,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -629,6 +634,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline, AudioLDMPipeline,
CLIPImageProjection, CLIPImageProjection,
CycleDiffusionPipeline, CycleDiffusionPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPipeline, HunyuanDiTPipeline,
I2VGenXLPipeline, I2VGenXLPipeline,
IFImg2ImgPipeline, IFImg2ImgPipeline,
......
...@@ -33,6 +33,7 @@ if is_torch_available(): ...@@ -33,6 +33,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"] _import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"] _import_structure["embeddings"] = ["ImageProjection"]
...@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel, VQModel,
) )
from .controlnet import ControlNetModel from .controlnet import ControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection from .embeddings import ImageProjection
......
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Union
import torch
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .attention_processor import AttentionProcessor
from .controlnet import BaseOutput, Tuple, zero_module
from .embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
from .modeling_utils import ModelMixin
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class HunyuanControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "gelu-approximate",
sample_size=32,
hidden_size=1152,
transformer_num_layers: int = 40,
mlp_ratio: float = 4.0,
cross_attention_dim: int = 1024,
cross_attention_dim_t5: int = 2048,
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
):
super().__init__()
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.text_embedder = PixArtAlphaTextProjection(
in_features=cross_attention_dim_t5,
hidden_size=cross_attention_dim_t5 * 4,
out_features=cross_attention_dim,
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
in_channels=in_channels,
embed_dim=hidden_size,
patch_size=patch_size,
pos_embed_type=None,
)
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
hidden_size,
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
# HunyuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=False, # always False as it is the first half of the model
)
for layer in range(transformer_num_layers // 2 - 1)
]
)
self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))
for _ in range(len(self.blocks)):
controlnet_block = nn.Linear(hidden_size, hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
corresponding cross attention processor. This is strongly recommended when setting trainable attention
processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
@classmethod
def from_transformer(
cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
):
config = transformer.config
activation_fn = config.activation_fn
attention_head_dim = config.attention_head_dim
cross_attention_dim = config.cross_attention_dim
cross_attention_dim_t5 = config.cross_attention_dim_t5
hidden_size = config.hidden_size
in_channels = config.in_channels
mlp_ratio = config.mlp_ratio
num_attention_heads = config.num_attention_heads
patch_size = config.patch_size
sample_size = config.sample_size
text_len = config.text_len
text_len_t5 = config.text_len_t5
conditioning_channels = conditioning_channels
transformer_num_layers = transformer_num_layers or config.transformer_num_layers
controlnet = cls(
conditioning_channels=conditioning_channels,
transformer_num_layers=transformer_num_layers,
activation_fn=activation_fn,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
cross_attention_dim_t5=cross_attention_dim_t5,
hidden_size=hidden_size,
in_channels=in_channels,
mlp_ratio=mlp_ratio,
num_attention_heads=num_attention_heads,
patch_size=patch_size,
sample_size=sample_size,
text_len=text_len,
text_len_t5=text_len_t5,
)
if load_weights_from_transformer:
key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
return controlnet
def forward(
self,
hidden_states,
timestep,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DControlNetModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
controlnet_cond ( `torch.Tensor` ):
The conditioning input to ControlNet.
conditioning_scale ( `float` ):
Indicate the conditioning scale.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states) # b,c,H,W -> b, N, C
# 2. pre-process
hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))
temb = self.time_extra_emb(
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
block_res_samples = ()
for layer, block in enumerate(self.blocks):
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) # (N, L, D)
block_res_samples = block_res_samples + (hidden_states,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:
return (controlnet_block_res_samples,)
return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
class HunyuanDiT2DMultiControlNetModel(ModelMixin):
r"""
`HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel
This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
designed to be compatible with `HunyuanDiT2DControlNetModel`.
Args:
controlnets (`List[HunyuanDiT2DControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`HunyuanDiT2DControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states,
timestep,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DControlNetModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
controlnet_cond ( `torch.Tensor` ):
The conditioning input to ControlNet.
conditioning_scale ( `float` ):
Indicate the conditioning scale.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
block_samples = controlnet(
hidden_states=hidden_states,
timestep=timestep,
controlnet_cond=image,
conditioning_scale=scale,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
]
control_block_samples = (control_block_samples,)
return control_block_samples
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. # Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -437,6 +437,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): ...@@ -437,6 +437,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
image_meta_size=None, image_meta_size=None,
style=None, style=None,
image_rotary_emb=None, image_rotary_emb=None,
controlnet_block_samples=None,
return_dict=True, return_dict=True,
): ):
""" """
...@@ -491,7 +492,10 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): ...@@ -491,7 +492,10 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
skips = [] skips = []
for layer, block in enumerate(self.blocks): for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2: if layer > self.config.num_layers // 2:
skip = skips.pop() if controlnet_block_samples is not None:
skip = skips.pop() + controlnet_block_samples.pop()
else:
skip = skips.pop()
hidden_states = block( hidden_states = block(
hidden_states, hidden_states,
temb=temb, temb=temb,
...@@ -510,6 +514,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): ...@@ -510,6 +514,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
if layer < (self.config.num_layers // 2 - 1): if layer < (self.config.num_layers // 2 - 1):
skips.append(hidden_states) skips.append(hidden_states)
if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
raise ValueError("The number of controls is not equal to the number of skip connections.")
# final layer # final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
......
...@@ -20,6 +20,7 @@ from ..utils import ( ...@@ -20,6 +20,7 @@ from ..utils import (
_dummy_objects = {} _dummy_objects = {}
_import_structure = { _import_structure = {
"controlnet": [], "controlnet": [],
"controlnet_hunyuandit": [],
"controlnet_sd3": [], "controlnet_sd3": [],
"controlnet_xs": [], "controlnet_xs": [],
"deprecated": [], "deprecated": [],
...@@ -152,6 +153,11 @@ else: ...@@ -152,6 +153,11 @@ else:
"StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLControlNetXSPipeline",
] ]
) )
_import_structure["controlnet_hunyuandit"].extend(
[
"HunyuanDiTControlNetPipeline",
]
)
_import_structure["controlnet_sd3"].extend( _import_structure["controlnet_sd3"].extend(
[ [
"StableDiffusion3ControlNetPipeline", "StableDiffusion3ControlNetPipeline",
...@@ -409,6 +415,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -409,6 +415,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetPipeline,
) )
from .controlnet_hunyuandit import (
HunyuanDiTControlNetPipeline,
)
from .controlnet_sd3 import ( from .controlnet_sd3 import (
StableDiffusion3ControlNetPipeline, StableDiffusion3ControlNetPipeline,
) )
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline
import torch
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
)
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to("cuda")
from diffusers.utils import load_image
cond_image = load_image(
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
)
## You may also use English prompt as HunyuanDiT supports both English and Chinese
prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围"
# prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
image = pipe(
prompt,
height=1024,
width=1024,
control_image=cond_image,
num_inference_steps=50,
).images[0]
```
"""
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]
def map_to_standard_shapes(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def get_resize_crop_region_for_grid(src, tgt_size):
th = tw = tgt_size
h, w = src
r = h / w
# resize
if r > 1:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class HunyuanDiTControlNetPipeline(DiffusionPipeline):
r"""
Pipeline for English/Chinese-to-image generation using HunyuanDiT.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
ourselves)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
`sdxl-vae-fp16-fix`.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
HunyuanDiT uses a fine-tuned [bilingual CLIP].
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
transformer ([`HunyuanDiT2DModel`]):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]):
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
"safety_checker",
"feature_extractor",
"text_encoder_2",
"tokenizer_2",
"text_encoder",
"tokenizer",
]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: BertModel,
tokenizer: BertTokenizer,
transformer: HunyuanDiT2DModel,
scheduler: DDPMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
controlnet: Union[
HunyuanDiT2DControlNetModel,
List[HunyuanDiT2DControlNetModel],
Tuple[HunyuanDiT2DControlNetModel],
HunyuanDiT2DMultiControlNetModel,
],
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
requires_safety_checker: bool = True,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
text_encoder_2=text_encoder_2,
controlnet=controlnet,
)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt
def encode_prompt(
self,
prompt: str,
device: torch.device = None,
dtype: torch.dtype = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
text_encoder_index: int = 0,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for clip and `1` for T5.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
if device is None:
device = self._execution_device
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = tokenizers[text_encoder_index]
text_encoder = text_encoders[text_encoder_index]
if max_sequence_length is None:
if text_encoder_index == 0:
max_length = 77
if text_encoder_index == 1:
max_length = 256
else:
max_length = max_sequence_length
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
control_image: PipelineImageInput = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = int((height // 16) * 16)
width = int((width // 16) * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=77,
text_encoder_index=0,
)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=256,
text_encoder_index=1,
)
# 4. Prepare control image
if isinstance(self.controlnet, HunyuanDiT2DControlNetModel):
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=False,
)
height, width = control_image.shape[-2:]
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = control_image * self.vae.config.scaling_factor
elif isinstance(self.controlnet, HunyuanDiT2DMultiControlNetModel):
control_images = []
for control_image_ in control_image:
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=False,
)
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = control_image_ * self.vae.config.scaling_factor
control_images.append(control_image_)
control_image = control_images
else:
assert False
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# controlnet(s) inference
control_block_samples = self.controlnet(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=False,
controlnet_cond=control_image,
conditioning_scale=controlnet_conditioning_scale,
)[0]
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=False,
controlnet_block_samples=control_block_samples,
)[0]
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
...@@ -122,6 +122,21 @@ class DiTTransformer2DModel(metaclass=DummyObject): ...@@ -122,6 +122,21 @@ class DiTTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanDiT2DModel(metaclass=DummyObject): class HunyuanDiT2DModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -137,6 +152,21 @@ class HunyuanDiT2DModel(metaclass=DummyObject): ...@@ -137,6 +152,21 @@ class HunyuanDiT2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class I2VGenXLUNet(metaclass=DummyObject): class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -212,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject): ...@@ -212,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTPipeline(metaclass=DummyObject): class HunyuanDiTPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc and Tencent Hunyuan Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDPMScheduler,
HunyuanDiT2DModel,
HunyuanDiTControlNetPipeline,
)
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = HunyuanDiTControlNetPipeline
params = frozenset(
[
"prompt",
"height",
"width",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = HunyuanDiT2DModel(
sample_size=16,
num_layers=4,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
in_channels=4,
cross_attention_dim=32,
cross_attention_dim_t5=32,
pooled_projection_dim=16,
hidden_size=24,
activation_fn="gelu-approximate",
)
torch.manual_seed(0)
controlnet = HunyuanDiT2DControlNetModel(
sample_size=16,
transformer_num_layers=4,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
in_channels=4,
cross_attention_dim=32,
cross_attention_dim_t5=32,
pooled_projection_dim=16,
hidden_size=24,
activation_fn="gelu-approximate",
)
torch.manual_seed(0)
vae = AutoencoderKL()
scheduler = DDPMScheduler()
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"safety_checker": None,
"feature_extractor": None,
"controlnet": controlnet,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
control_image = randn_tensor(
(1, 3, 16, 16),
generator=generator,
device=torch.device(device),
dtype=torch.float16,
)
controlnet_conditioning_scale = 0.5
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"control_image": control_image,
"controlnet_conditioning_scale": controlnet_conditioning_scale,
}
return inputs
def test_controlnet_hunyuandit(self):
components = self.get_dummy_components()
pipe = HunyuanDiTControlNetPipeline(**components)
pipe = pipe.to(torch_device, dtype=torch.float16)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 16, 16, 3)
expected_slice = np.array(
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
expected_max_diff=1e-3,
)
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
def test_save_load_optional_components(self):
# TODO(YiYi) need to fix later
pass
@slow
@require_torch_gpu
class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = HunyuanDiTControlNetPipeline
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_canny(self):
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
)
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
n_prompt = ""
control_image = load_image(
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
)
output = pipe(
prompt,
negative_prompt=n_prompt,
control_image=control_image,
controlnet_conditioning_scale=0.5,
guidance_scale=5.0,
num_inference_steps=2,
output_type="np",
generator=generator,
)
image = output.images[0]
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.43652344, 0.4399414, 0.44921875, 0.45043945, 0.45703125, 0.44873047, 0.43579102, 0.44018555, 0.42578125]
)
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
def test_pose(self):
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16
)
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "An Asian woman, dressed in a green top, wearing a purple headscarf and a purple scarf, stands in front of a blackboard. The background is the blackboard. The photo is presented in a close-up, eye-level, and centered composition, adopting a realistic photographic style"
n_prompt = ""
control_image = load_image(
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose/resolve/main/pose.jpg?download=true"
)
output = pipe(
prompt,
negative_prompt=n_prompt,
control_image=control_image,
controlnet_conditioning_scale=0.5,
guidance_scale=5.0,
num_inference_steps=2,
output_type="np",
generator=generator,
)
image = output.images[0]
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.4091797, 0.4177246, 0.39526367, 0.4194336, 0.40356445, 0.3857422, 0.39208984, 0.40429688, 0.37451172]
)
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
def test_depth(self):
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Depth", torch_dtype=torch.float16
)
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "In the dense forest, a black and white panda sits quietly in green trees and red flowers, surrounded by mountains, rivers, and the ocean. The background is the forest in a bright environment."
n_prompt = ""
control_image = load_image(
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Depth/resolve/main/depth.jpg?download=true"
)
output = pipe(
prompt,
negative_prompt=n_prompt,
control_image=control_image,
controlnet_conditioning_scale=0.5,
guidance_scale=5.0,
num_inference_steps=2,
output_type="np",
generator=generator,
)
image = output.images[0]
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.31982422, 0.32177734, 0.30126953, 0.3190918, 0.3100586, 0.31396484, 0.3232422, 0.33544922, 0.30810547]
)
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
def test_multi_controlnet(self):
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
)
controlnet = HunyuanDiT2DMultiControlNetModel([controlnet, controlnet])
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
n_prompt = ""
control_image = load_image(
"https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
)
output = pipe(
prompt,
negative_prompt=n_prompt,
control_image=[control_image, control_image],
controlnet_conditioning_scale=[0.25, 0.25],
guidance_scale=5.0,
num_inference_steps=2,
output_type="np",
generator=generator,
)
image = output.images[0]
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.43652344, 0.44018555, 0.4494629, 0.44995117, 0.45654297, 0.44848633, 0.43603516, 0.4404297, 0.42626953]
)
assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment