"vscode:/vscode.git/clone" did not exist on "e874ef045b074b4a61d1f44aa0922b2de4706f92"
Unverified commit 3cfe187d, authored by Dhruv Nair and committed by GitHub
Browse files

Cleanup ControlnetXS (#7701)

* update

* update
parent 90250d9e
......@@ -22,7 +22,14 @@ from torch import FloatTensor, nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .attention_processor import Attention, AttentionProcessor
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .controlnet import ControlNetConditioningEmbedding
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
......@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
return processors
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
......@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.

    The replacement is chosen from the classes of the processors currently reported by
    `self.attn_processors`: `AttnAddedKVProcessor` when every one of them belongs to
    `ADDED_KV_ATTENTION_PROCESSORS`, `AttnProcessor` when every one of them belongs to
    `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: If the current processors fit neither group entirely.
    """

    def _every_processor_in(allowed_classes):
        # True when each currently installed processor's class is part of `allowed_classes`.
        return all(proc.__class__ in allowed_classes for proc in self.attn_processors.values())

    # The added-KV group is checked first, so it wins whenever both groups would match.
    if _every_processor_in(ADDED_KV_ATTENTION_PROCESSORS):
        default_processor = AttnAddedKVProcessor()
    elif _every_processor_in(CROSS_ATTENTION_PROCESSORS):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default_processor)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
......@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
......@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
......@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
......
......@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt,
prompt_2,
image,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
......@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline(
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
......@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
......@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline(
returned, otherwise a `tuple` is returned containing the output images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
......@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt,
prompt_2,
image,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
......@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline(
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
......
......@@ -69,6 +69,13 @@ from ..test_pipelines_common import (
enable_full_determinism()
def to_np(tensor):
    """Return `tensor` as a NumPy array if it is a `torch.Tensor`; otherwise return it unchanged."""
    if not isinstance(tensor, torch.Tensor):
        # Anything that is not a torch tensor is passed through as-is.
        return tensor
    # Detach from autograd and move to CPU before converting to NumPy.
    return tensor.detach().cpu().numpy()
# Will be run via run_test_in_subprocess
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
error = None
......@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests(
assert out_vae_np.shape == out_np.shape
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
    """Move the pipeline to CPU and then to CUDA, checking component devices and NaN-free output each time."""
    dummy_components = self.get_dummy_components()
    pipe = self.pipeline_class(**dummy_components)
    pipe.set_progress_bar_config(disable=None)

    def component_device_types():
        # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from
        # pipe.components. Components without a `device` attribute are left out.
        return [part.device.type for part in pipe.components.values() if hasattr(part, "device")]

    # --- CPU ---
    pipe.to("cpu")
    self.assertTrue(all(kind == "cpu" for kind in component_device_types()))

    output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
    self.assertTrue(np.isnan(output_cpu).sum() == 0)

    # --- CUDA ---
    pipe.to("cuda")
    self.assertTrue(all(kind == "cuda" for kind in component_device_types()))

    output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
    self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment