"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "358e211cde4412c24675af3d048f2d6d4391df59"
Unverified commit 3cfe187d, authored by Dhruv Nair, committed by GitHub
Browse files

Cleanup ControlnetXS (#7701)

* update

* update
parent 90250d9e
...@@ -22,7 +22,14 @@ from torch import FloatTensor, nn ...@@ -22,7 +22,14 @@ from torch import FloatTensor, nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_torch_version, logging from ..utils import BaseOutput, is_torch_version, logging
from ..utils.torch_utils import apply_freeu from ..utils.torch_utils import apply_freeu
from .attention_processor import Attention, AttentionProcessor from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .controlnet import ControlNetConditioningEmbedding from .controlnet import ControlNetConditioningEmbedding
from .embeddings import TimestepEmbedding, Timesteps from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
...@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
return processors return processors
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Replace all installed attention processors with the default one for their family.

    If every installed processor class belongs to `ADDED_KV_ATTENTION_PROCESSORS`, an `AttnAddedKVProcessor` is
    installed; if every one belongs to `CROSS_ATTENTION_PROCESSORS`, an `AttnProcessor` is installed.

    Raises:
        ValueError: If the installed processors fit neither family as a whole.
    """

    def installed_all_in(family):
        # True when each currently installed processor's class is a member of `family`.
        return all(proc.__class__ in family for proc in self.attn_processors.values())

    if installed_all_in(ADDED_KV_ATTENTION_PROCESSORS):
        default_processor = AttnAddedKVProcessor()
    elif installed_all_in(CROSS_ATTENTION_PROCESSORS):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default_processor)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
setattr(upsample_block, "b1", b1) setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2) setattr(upsample_block, "b2", b2)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self): def disable_freeu(self):
"""Disables the FreeU mechanism.""" """Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
...@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None) setattr(upsample_block, k, None)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self): def fuse_qkv_projections(self):
""" """
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
...@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ...@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if isinstance(module, Attention): if isinstance(module, Attention):
module.fuse_projections(fuse=True) module.fuse_projections(fuse=True)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self): def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled. """Disables the fused QKV projection if enabled.
......
...@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder ...@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers, scale_lora_layers,
...@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt, prompt,
prompt_2, prompt_2,
image, image,
callback_steps,
negative_prompt=None, negative_prompt=None,
negative_prompt_2=None, negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
...@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline(
control_guidance_end=1.0, control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
): ):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
): ):
...@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline(
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline(
returned, otherwise a `tuple` is returned containing the output images. returned, otherwise a `tuple` is returned containing the output images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
...@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt, prompt,
prompt_2, prompt_2,
image, image,
callback_steps,
negative_prompt, negative_prompt,
negative_prompt_2, negative_prompt_2,
prompt_embeds, prompt_embeds,
...@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline(
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# manually for max memory savings # manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
......
...@@ -69,6 +69,13 @@ from ..test_pipelines_common import ( ...@@ -69,6 +69,13 @@ from ..test_pipelines_common import (
enable_full_determinism() enable_full_determinism()
def to_np(tensor):
    """Convert a `torch.Tensor` to a NumPy array; return any other input unchanged."""
    if not isinstance(tensor, torch.Tensor):
        return tensor

    # Detach from the autograd graph and move to host memory before converting.
    return tensor.detach().cpu().numpy()
# Will be run via run_test_in_subprocess # Will be run via run_test_in_subprocess
def _test_stable_diffusion_compile(in_queue, out_queue, timeout): def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
error = None error = None
...@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests( ...@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests(
assert out_vae_np.shape == out_np.shape assert out_vae_np.shape == out_np.shape
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
    """Move the pipeline to CPU and then CUDA, checking component devices and NaN-free output each time."""
    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    def component_device_types():
        # pipeline creates a new UNetControlNetXSModel under the hood, so devices are read from pipe.components
        return [part.device.type for part in pipe.components.values() if hasattr(part, "device")]

    pipe.to("cpu")
    self.assertTrue(all(kind == "cpu" for kind in component_device_types()))

    output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
    self.assertTrue(np.isnan(output_cpu).sum() == 0)

    pipe.to("cuda")
    self.assertTrue(all(kind == "cuda" for kind in component_device_types()))

    output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
    self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment