Unverified commit a840c39a authored by Aryan, committed by GitHub

[refactor] Make guiders return their inputs (#12213)

* update

* update

* apply review suggestions

* remove guider inputs

* fix tests
parent 9a7ae77a
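
For orientation: this commit replaces the guiders' old `(pred, scheduler_step_kwargs)` tuple return with a `GuiderOutput` that also carries the conditional and unconditional predictions the guider received. Below is a minimal, self-contained sketch of the two contracts; `ToyGuiderOutput` and `run_guidance` are hypothetical stand-ins written only to illustrate the pattern (the real `GuiderOutput` subclasses diffusers' `BaseOutput`).

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ToyGuiderOutput:
    # Hypothetical stand-in for diffusers' GuiderOutput(BaseOutput):
    # the guidance result plus the inputs that produced it.
    pred: torch.Tensor
    pred_cond: Optional[torch.Tensor] = None
    pred_uncond: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        # BaseOutput subclasses support tuple-style indexing; index 0 is `pred`,
        # which is why callers can write `components.guider(state)[0]` after this refactor.
        return (self.pred, self.pred_cond, self.pred_uncond)[idx]

def run_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, scale: float = 7.5):
    # Classic classifier-free guidance combination (illustrative only).
    pred = pred_uncond + scale * (pred_cond - pred_uncond)
    # Old contract: `return pred, {}` (a prediction plus empty scheduler kwargs).
    # New contract: return the prediction together with the guider's inputs.
    return ToyGuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

out = run_guidance(torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))
noise_pred = out[0]                       # tuple-style access, as the modular blocks now do
assert torch.equal(noise_pred, out.pred)  # attribute access works too

As the hunks below show, the second tuple element was always an empty dict (`return pred, {}`), which is why the modular blocks' `scheduler_step_kwargs` plumbing is removed in the same commit.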
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -92,7 +92,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if not self._is_apg_enabled():
@@ -111,7 +111,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
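
The `rescale_noise_cfg` call in the hunk above is untouched by this refactor. For context, here is a hedged sketch of the guidance-rescale computation as it is commonly implemented (treat the exact details as an assumption about the library's helper; only its signature appears in the hunks below):

import torch

def rescale_noise_cfg_sketch(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor,
                             guidance_rescale: float = 0.0) -> torch.Tensor:
    # Match the standard deviation of the guided prediction to that of the
    # conditional prediction, then blend the two by `guidance_rescale`.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg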
@@ -20,7 +20,7 @@ import torch
 from ..configuration_utils import register_to_config
 from ..hooks import HookRegistry, LayerSkipConfig
 from ..hooks.layer_skip import _apply_layer_skip_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -145,7 +145,7 @@ class AutoGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if not self._is_ag_enabled():
@@ -158,7 +158,7 @@ class AutoGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -96,7 +96,7 @@ class ClassifierFreeGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if not self._is_cfg_enabled():
@@ -109,7 +109,7 @@ class ClassifierFreeGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -89,7 +89,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if self._step < self.zero_init_steps:
@@ -109,7 +109,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -19,7 +19,7 @@ import torch
 from ..configuration_utils import register_to_config
 from ..utils import is_kornia_available
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -230,7 +230,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if not self._is_fdg_enabled():
@@ -277,7 +277,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
         if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -20,7 +20,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from typing_extensions import Self
 from ..configuration_utils import ConfigMixin
-from ..utils import PushToHubMixin, get_logger
+from ..utils import BaseOutput, PushToHubMixin, get_logger
 if TYPE_CHECKING:
@@ -284,6 +284,12 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+class GuiderOutput(BaseOutput):
+    pred: torch.Tensor
+    pred_cond: Optional[torch.Tensor]
+    pred_uncond: Optional[torch.Tensor]
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     r"""
     Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
...
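
Since `GuiderOutput` subclasses `BaseOutput`, it behaves both like a dataclass and like a tuple of its non-None fields. A short, hedged usage sketch, assuming the usual `BaseOutput` semantics and that the class is importable as `diffusers.guiders.guider_utils.GuiderOutput` (the relative import in the hunks suggests this path):

import torch
from diffusers.guiders.guider_utils import GuiderOutput  # assumed import path, per the hunk above

pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
pred = pred_uncond + 7.5 * (pred_cond - pred_uncond)

out = GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

out.pred            # attribute access
out["pred_cond"]    # dict-style access (BaseOutput behaves like an ordered mapping)
out[0]              # tuple-style access: the first non-None field, i.e. `pred`
out.to_tuple()      # all non-None fields as a plain tuple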
@@ -21,7 +21,7 @@ from ..configuration_utils import register_to_config
 from ..hooks import HookRegistry, LayerSkipConfig
 from ..hooks.layer_skip import _apply_layer_skip_hook
 from ..utils import get_logger
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -197,7 +197,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
         pred_cond: torch.Tensor,
         pred_uncond: Optional[torch.Tensor] = None,
         pred_cond_skip: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> GuiderOutput:
         pred = None
         if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -219,7 +219,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
...
@@ -20,7 +20,7 @@ import torch
 from ..configuration_utils import register_to_config
 from ..hooks import HookRegistry, LayerSkipConfig
 from ..hooks.layer_skip import _apply_layer_skip_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -192,7 +192,7 @@ class SkipLayerGuidance(BaseGuidance):
         pred_cond: torch.Tensor,
         pred_uncond: Optional[torch.Tensor] = None,
         pred_cond_skip: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> GuiderOutput:
         pred = None
         if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -214,7 +214,7 @@ class SkipLayerGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -20,7 +20,7 @@ import torch
 from ..configuration_utils import register_to_config
 from ..hooks import HookRegistry
 from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -181,7 +181,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
         pred_cond: torch.Tensor,
         pred_uncond: Optional[torch.Tensor] = None,
         pred_cond_seg: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> GuiderOutput:
         pred = None
         if not self._is_cfg_enabled() and not self._is_seg_enabled():
@@ -203,7 +203,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
 if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
             data_batches.append(data_batch)
         return data_batches
-    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
         if not self._is_tcfg_enabled():
@@ -89,7 +89,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
         if self.guidance_rescale > 0.0:
             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
-        return pred, {}
+        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
     @property
     def is_conditional(self) -> bool:
...
@@ -238,7 +238,7 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
             components.guider.cleanup_models(components.unet)
         # Perform guidance
-        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+        block_state.noise_pred = components.guider(guider_state)[0]
         return components, block_state
@@ -433,7 +433,7 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
             components.guider.cleanup_models(components.unet)
         # Perform guidance
-        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+        block_state.noise_pred = components.guider(guider_state)[0]
         return components, block_state
@@ -492,7 +492,6 @@ class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
             t,
             block_state.latents,
             **block_state.extra_step_kwargs,
-            **block_state.scheduler_step_kwargs,
             return_dict=False,
         )[0]
@@ -590,7 +589,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
            t,
            block_state.latents,
            **block_state.extra_step_kwargs,
-            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]
...
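
On the pipeline side the change is symmetric: the denoiser blocks keep only the first element of the guider output, and the scheduler step no longer receives `scheduler_step_kwargs`. A hedged, self-contained sketch of the resulting step call; `ToyScheduler` is a hypothetical stand-in, not a diffusers scheduler:

import torch

class ToyScheduler:
    # Hypothetical stand-in for a scheduler's `step` API surface.
    def step(self, model_output, t, sample, return_dict=False, **kwargs):
        prev_sample = sample - 0.1 * model_output  # placeholder update rule
        return (prev_sample,) if not return_dict else {"prev_sample": prev_sample}

scheduler = ToyScheduler()
noise_pred = torch.randn(1, 4, 8, 8)
latents = torch.randn(1, 4, 8, 8)

# Before: the block forwarded whatever kwargs the guider returned, e.g.
#   latents = scheduler.step(noise_pred, t, latents, **scheduler_step_kwargs, return_dict=False)[0]
# After: the guider no longer supplies step kwargs, so the call is plain.
latents = scheduler.step(noise_pred, 0, latents, return_dict=False)[0]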
@@ -127,7 +127,7 @@ class WanLoopDenoiser(ModularPipelineBlocks):
             components.guider.cleanup_models(components.transformer)
         # Perform guidance
-        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+        block_state.noise_pred = components.guider(guider_state)[0]
         return components, block_state
@@ -171,7 +171,6 @@
            block_state.noise_pred.float(),
            t,
            block_state.latents.float(),
-            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]
...