Unverified commit a138d71e, authored by YiYi Xu and committed by GitHub

HunyuanImage21 (#12333)



* add hunyuanimage2.1


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bc403988
...@@ -347,6 +347,8 @@ ...@@ -347,6 +347,8 @@
title: HiDreamImageTransformer2DModel title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d - local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d - local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d - local: api/models/latte_transformer3d
...@@ -411,6 +413,10 @@ ...@@ -411,6 +413,10 @@
title: AutoencoderKLCogVideoX title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos - local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video - local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video - local: api/models/autoencoderkl_ltx_video
...@@ -620,6 +626,8 @@ ...@@ -620,6 +626,8 @@
title: ConsisID title: ConsisID
- local: api/pipelines/framepack - local: api/pipelines/framepack
title: Framepack title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video - local: api/pipelines/hunyuan_video
title: HunyuanVideo title: HunyuanVideo
- local: api/pipelines/i2vgenxl - local: api/pipelines/i2vgenxl
......
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImage
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLHunyuanImage
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
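For a quick sanity check of the decoder, latents can be decoded directly. The sketch below is illustrative only: it assumes the config exposes `latent_channels` like the other KL autoencoders in the library, and the 64x64 latent grid is an arbitrary size rather than the model's native resolution.
```python
import torch
from diffusers import AutoencoderKLHunyuanImage

vae = AutoencoderKLHunyuanImage.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical random latents: the channel count is read from the config (assumed attribute),
# and the spatial size is chosen arbitrarily for illustration.
latents = torch.randn(1, vae.config.latent_channels, 64, 64, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    decoded = vae.decode(latents).sample  # `DecoderOutput.sample` holds the decoded image tensor
```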
## AutoencoderKLHunyuanImage
[[autodoc]] AutoencoderKLHunyuanImage
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImageRefiner
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLHunyuanImageRefiner
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImageRefiner
[[autodoc]] AutoencoderKLHunyuanImageRefiner
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanImageTransformer2DModel
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import HunyuanImageTransformer2DModel
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanImageTransformer2DModel
[[autodoc]] HunyuanImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanImage2.1
HunyuanImage-2.1 is a 17B text-to-image model capable of generating 2K (2048 x 2048) resolution images.
HunyuanImage-2.1 comes in the following variants:
| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
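A minimal sketch of what enabling a cache could look like, assuming the HunyuanImage transformer supports the generic `enable_cache` API used by other recent transformers in the library (see the linked caching guide for the configurations that actually apply); the threshold value here is illustrative only.
```python
import torch
from diffusers import FirstBlockCacheConfig, HunyuanImagePipeline

pipe = HunyuanImagePipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Assumes the transformer exposes the generic CacheMixin API; threshold chosen for illustration.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```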
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` exposes this as a `guider` component (see the [Guider](../modular_diffusers/guiders.md) guide) and does not take a `guidance_scale` argument at runtime. To change guider-related parameters such as `guidance_scale`, update the `guider` configuration instead.
```python
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```
You can inspect the `guider` object:
```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
"_class_name": "AdaptiveProjectedMixGuidance",
"_diffusers_version": "0.36.0.dev0",
"adaptive_projected_guidance_momentum": -0.5,
"adaptive_projected_guidance_rescale": 10.0,
"adaptive_projected_guidance_scale": 10.0,
"adaptive_projected_guidance_start_step": 5,
"enabled": true,
"eta": 0.0,
"guidance_rescale": 0.0,
"guidance_scale": 3.5,
"start": 0.0,
"stop": 1.0,
"use_original_formulation": false
}
State:
step: None
num_inference_steps: None
timestep: None
count_prepared: 0
enabled: True
num_conditions: 2
momentum_buffer: None
is_apg_enabled: False
is_cfg_enabled: True
```
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
image = pipe(
prompt=prompt,
num_inference_steps=50,
height=2048,
width=2048,
).images[0]
image.save("image.png")
```
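Continuing from the example above, the guider can also be disabled entirely so the pipeline uses only conditional predictions. This is a minimal sketch relying on the `enabled` flag visible in the guider config shown earlier:
```py
# Recreate the guider with guidance turned off; all other config values are kept.
pipe.guider = pipe.guider.new(enabled=False)

image = pipe(
    prompt=prompt,
    num_inference_steps=50,
    height=2048,
    width=2048,
).images[0]
image.save("image_no_guidance.png")
```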
## HunyuanImage-2.1-Distilled
Use `distilled_guidance_scale` with the guidance-distilled checkpoint:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
generator = torch.Generator("cuda").manual_seed(0)
image = pipe(
    prompt,
    num_inference_steps=8,
    distilled_guidance_scale=3.25,
    height=2048,
    width=2048,
    generator=generator,
).images[0]
image.save("image_distilled.png")
```
## HunyuanImagePipeline
[[autodoc]] HunyuanImagePipeline
- all
- __call__
## HunyuanImageRefinerPipeline
[[autodoc]] HunyuanImageRefinerPipeline
- all
- __call__
## HunyuanImagePipelineOutput
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
\ No newline at end of file
This diff is collapsed.
...@@ -149,7 +149,9 @@ else: ...@@ -149,7 +149,9 @@ else:
_import_structure["guiders"].extend( _import_structure["guiders"].extend(
[ [
"AdaptiveProjectedGuidance", "AdaptiveProjectedGuidance",
"AdaptiveProjectedMixGuidance",
"AutoGuidance", "AutoGuidance",
"BaseGuidance",
"ClassifierFreeGuidance", "ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance", "ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance", "FrequencyDecoupledGuidance",
...@@ -184,6 +186,8 @@ else: ...@@ -184,6 +186,8 @@ else:
"AutoencoderKLAllegro", "AutoencoderKLAllegro",
"AutoencoderKLCogVideoX", "AutoencoderKLCogVideoX",
"AutoencoderKLCosmos", "AutoencoderKLCosmos",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo", "AutoencoderKLLTXVideo",
"AutoencoderKLMagvit", "AutoencoderKLMagvit",
...@@ -216,6 +220,7 @@ else: ...@@ -216,6 +220,7 @@ else:
"HunyuanDiT2DControlNetModel", "HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel", "HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel", "HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel", "HunyuanVideoTransformer3DModel",
"I2VGenXLUNet", "I2VGenXLUNet",
...@@ -462,6 +467,8 @@ else: ...@@ -462,6 +467,8 @@ else:
"HunyuanDiTControlNetPipeline", "HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline", "HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline", "HunyuanDiTPipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline", "HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline", "HunyuanVideoImageToVideoPipeline",
...@@ -849,7 +856,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -849,7 +856,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .guiders import ( from .guiders import (
AdaptiveProjectedGuidance, AdaptiveProjectedGuidance,
AdaptiveProjectedMixGuidance,
AutoGuidance, AutoGuidance,
BaseGuidance,
ClassifierFreeGuidance, ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance, ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance, FrequencyDecoupledGuidance,
...@@ -880,6 +889,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -880,6 +889,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro, AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLCosmos, AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo, AutoencoderKLLTXVideo,
AutoencoderKLMagvit, AutoencoderKLMagvit,
...@@ -912,6 +923,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -912,6 +923,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel, HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel, HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel, HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel, HunyuanVideoTransformer3DModel,
I2VGenXLUNet, I2VGenXLUNet,
...@@ -1128,6 +1140,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -1128,6 +1140,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline, HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline, HunyuanDiTPAGPipeline,
HunyuanDiTPipeline, HunyuanDiTPipeline,
HunyuanImagePipeline,
HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline, HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline, HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline, HunyuanVideoImageToVideoPipeline,
......
...@@ -14,28 +14,24 @@ ...@@ -14,28 +14,24 @@
from typing import Union from typing import Union
from ..utils import is_torch_available from ..utils import is_torch_available, logging
logger = logging.get_logger(__name__)
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
if is_torch_available(): if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch import torch
...@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance): ...@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
...@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance): ...@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
self.momentum_buffer = None self.momentum_buffer = None
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0: if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None: if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
...@@ -152,6 +148,44 @@ class MomentumBuffer: ...@@ -152,6 +148,44 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average new_average = self.momentum * self.running_average
self.running_average = update_value + new_average self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def normalized_guidance( def normalized_guidance(
pred_cond: torch.Tensor, pred_cond: torch.Tensor,
......
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedMixGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
Args:
guidance_scale (`float`, defaults to `3.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `-0.5`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `10.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix overexposure.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
used, and momentum buffer is updated).
enabled (`bool`, defaults to `True`):
Whether this guidance is enabled.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 3.5,
guidance_rescale: float = 0.0,
adaptive_projected_guidance_scale: float = 10.0,
adaptive_projected_guidance_momentum: float = -0.5,
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# no guidance
if not self._is_cfg_enabled():
pred = pred_cond
# CFG + update momentum buffer
elif not self._is_apg_enabled():
if self.momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
# standard CFG combination
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
# APG
elif self._is_apg_enabled():
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.adaptive_projected_guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled() or self._is_cfg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
if not self._is_cfg_enabled():
return False
is_within_range = False
if self._step is not None:
is_within_range = self._step > self.adaptive_projected_guidance_start_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
else:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
return is_within_range and not is_close
def get_state(self):
state = super().get_state()
state["momentum_buffer"] = self.momentum_buffer
state["is_apg_enabled"] = self._is_apg_enabled()
state["is_cfg_enabled"] = self._is_cfg_enabled()
return state
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def update_momentum_buffer(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
momentum_buffer: Optional[MomentumBuffer] = None,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
diff = momentum_buffer.running_average
else:
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
...@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance): ...@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers self.auto_guidance_layers = auto_guidance_layers
...@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance): ...@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True) registry.remove_hook(name, recurse=True)
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch import torch
...@@ -27,43 +27,50 @@ if TYPE_CHECKING: ...@@ -27,43 +27,50 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance): class ClassifierFreeGuidance(BaseGuidance):
""" """
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 Implements Classifier-Free Guidance (CFG) for diffusion models.
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by Reference: https://huggingface.co/papers/2207.12598
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen CFG improves generation quality and prompt adherence by jointly training models on both conditional and
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in unconditional data, then combining predictions during inference. This allows trading off between quality (high
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] guidance) and diversity (low guidance).
The intution behind the original formulation can be thought of as moving the conditional distribution estimates **Two CFG Formulations:**
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the 1. **Original formulation** (from paper):
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. ```
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
```
Moves conditional predictions further from unconditional ones.
2. **Diffusers-native formulation** (default, from Imagen paper):
```
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
```
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
quality", "watermarks"). Equivalent in theory but more intuitive.
Use `use_original_formulation=True` to switch to the original formulation.
Args: Args:
guidance_scale (`float`, defaults to `7.5`): guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and may reduce quality. Typical range: 1.0-20.0.
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`): guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
Flawed](https://huggingface.co/papers/2305.08891). to 1.0 (full rescaling).
use_original_formulation (`bool`, defaults to `False`): use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
we use the diffusers-native implementation that has been in the codebase for a long time. See diffusers-native formulation from the Imagen paper.
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`): start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts. Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
steps.
stop (`float`, defaults to `1.0`): stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops. Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
steps.
enabled (`bool`, defaults to `True`):
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
""" """
_input_predictions = ["pred_cond", "pred_uncond"] _input_predictions = ["pred_cond", "pred_uncond"]
...@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance): ...@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch import torch
...@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): ...@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None pred = None
if self._step < self.zero_init_steps: # YiYi Notes: add default behavior for self._enabled == False
if not self._enabled:
pred = pred_cond
elif self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond) pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled(): elif not self._is_cfg_enabled():
pred = pred_cond pred = pred_cond
......
...@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance): ...@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
stop: Union[float, List[float], Tuple[float]] = 1.0, stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data", guidance_rescale_space: str = "data",
upcast_to_double: bool = True, upcast_to_double: bool = True,
enabled: bool = True,
): ):
if not _CAN_USE_KORNIA: if not _CAN_USE_KORNIA:
raise ImportError( raise ImportError(
...@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance): ...@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component # Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start) min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop) max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop) super().__init__(min_start, max_stop, enabled)
self.guidance_scales = guidance_scales self.guidance_scales = guidance_scales
self.levels = len(guidance_scales) self.levels = len(guidance_scales)
...@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance): ...@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})" f"({len(self.guidance_scales)})"
) )
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -40,7 +40,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -40,7 +40,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None _input_predictions = None
_identifier_key = "__guidance_identifier__" _identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0): def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
self._start = start self._start = start
self._stop = stop self._stop = stop
self._step: int = None self._step: int = None
...@@ -48,7 +48,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -48,7 +48,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None self._timestep: torch.LongTensor = None
self._count_prepared = 0 self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True self._enabled = enabled
if not (0.0 <= start < 1.0): if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.") raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
...@@ -60,6 +60,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -60,6 +60,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique." "`_input_predictions` must be a list of required prediction names for the guidance technique."
) )
def new(self, **kwargs):
"""
Creates a copy of this guider instance, optionally with modified configuration parameters.
Args:
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
returns an exact copy with the same configuration.
Returns:
A new guider instance with the same (or updated) configuration.
Example:
```python
# Create a CFG guider
guider = ClassifierFreeGuidance(guidance_scale=3.5)
# Create an exact copy
same_guider = guider.new()
# Create a copy with different start step, keeping other config the same
new_guider = guider.new(guidance_scale=5)
```
"""
return self.__class__.from_config(self.config, **kwargs)
def disable(self): def disable(self):
self._enabled = False self._enabled = False
...@@ -72,42 +97,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -72,42 +97,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep self._timestep = timestep
self._count_prepared = 0 self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: def get_state(self) -> Dict[str, Any]:
"""
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
`Dict[str, Any]`: A dictionary containing the current state variables including:
- step: Current inference step
- num_inference_steps: Total number of inference steps
- timestep: Current timestep tensor
- count_prepared: Number of times prepare_models has been called
- enabled: Whether the guidance is enabled
- num_conditions: Number of conditions
"""
state = {
"step": self._step,
"num_inference_steps": self._num_inference_steps,
"timestep": self._timestep,
"count_prepared": self._count_prepared,
"enabled": self._enabled,
"num_conditions": self.num_conditions,
}
return state
def __repr__(self) -> str:
""" """
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned Returns a string representation of the guidance object including both config and current state.
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from """
the values of the provided keyword arguments to this method. # Get ConfigMixin's __repr__
str_repr = super().__repr__()
Args: # Get current state
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): state = self.get_state()
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a # Format each state variable on its own line with indentation
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the state_lines = []
conditional data identifier and the second element must be the unconditional data identifier or None. for k, v in state.items():
# Convert value to string and handle multi-line values
v_str = str(v)
if "\n" in v_str:
# For multi-line values (like MomentumBuffer), indent subsequent lines
v_lines = v_str.split("\n")
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
state_lines.append(f" {k}: {v_str}")
Example: state_str = "\n".join(state_lines)
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields( return f"{str_repr}\nState:\n{state_str}"
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None: def prepare_models(self, denoiser: torch.nn.Module) -> None:
""" """
...@@ -155,8 +190,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -155,8 +190,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod @classmethod
def _prepare_batch( def _prepare_batch(
cls, cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]], data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
data: "BlockState",
tuple_index: int, tuple_index: int,
identifier: str, identifier: str,
) -> "BlockState": ) -> "BlockState":
...@@ -182,21 +216,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): ...@@ -182,21 +216,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
""" """
from ..modular_pipelines.modular_pipeline import BlockState from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {} data_batch = {}
for key, value in input_fields.items(): for key, value in data.items():
try: try:
if isinstance(value, str): if isinstance(value, torch.Tensor):
data_batch[key] = getattr(data, value) data_batch[key] = value
elif isinstance(value, tuple): elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index]) data_batch[key] = value[tuple_index]
else: else:
# We've already checked that value is a string or a tuple of strings with length 2 raise ValueError(f"Invalid value type: {type(value)}")
pass except ValueError:
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.") logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch) return BlockState(**data_batch)
......
...@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance): ...@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale self.skip_layer_guidance_scale = perturbed_guidance_scale
...@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance): ...@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True) registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1: if self.num_conditions == 1:
tuple_indices = [0] tuple_indices = [0]
input_predictions = ["pred_cond"] input_predictions = ["pred_cond"]
...@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance): ...@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0] tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance): ...@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale self.skip_layer_guidance_scale = skip_layer_guidance_scale
...@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance): ...@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names: for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True) registry.remove_hook(hook_name, recurse=True)
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1: if self.num_conditions == 1:
tuple_indices = [0] tuple_indices = [0]
input_predictions = ["pred_cond"] input_predictions = ["pred_cond"]
...@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance): ...@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0] tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance): ...@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale self.seg_guidance_scale = seg_guidance_scale
...@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance): ...@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names: for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True) registry.remove_hook(hook_name, recurse=True)
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1: if self.num_conditions == 1:
tuple_indices = [0] tuple_indices = [0]
input_predictions = ["pred_cond"] input_predictions = ["pred_cond"]
...@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance): ...@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0] tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch import torch
...@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance): ...@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False, use_original_formulation: bool = False,
start: float = 0.0, start: float = 0.0,
stop: float = 1.0, stop: float = 1.0,
enabled: bool = True,
): ):
super().__init__(start, stop) super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation self.use_original_formulation = use_original_formulation
def prepare_inputs( def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1] tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = [] data_batches = []
for i in range(self.num_conditions): for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch) data_batches.append(data_batch)
return data_batches return data_batches
......
...@@ -108,6 +108,7 @@ def _register_attention_processors_metadata(): ...@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0 from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0 from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0 from ..models.transformers.transformer_wan import WanAttnProcessor2_0
...@@ -149,6 +150,14 @@ def _register_attention_processors_metadata(): ...@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
), ),
) )
# HunyuanImageAttnProcessor
AttentionProcessorRegistry.register(
model_class=HunyuanImageAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
),
)
def _register_transformer_blocks_metadata(): def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock from ..models.attention import BasicTransformerBlock
...@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata(): ...@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock, HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock, HunyuanVideoTransformerBlock,
) )
from ..models.transformers.transformer_hunyuanimage import (
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
...@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata(): ...@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
), ),
) )
# HunyuanImage2.1
TransformerBlockRegistry.register(
model_class=HunyuanImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanImageSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# fmt: off # fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs): def _skip_attention___ret___hidden_states(self, *args, **kwargs):
...@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid ...@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet. # not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on # fmt: on