"vscode:/vscode.git/clone" did not exist on "164a265714abc448337f7ee4be651454cacc85cc"
Unverified Commit 79dc7df0 authored by Andrés Romero's avatar Andrés Romero Committed by GitHub
Browse files

[bug fix] Inpainting for MultiAdapter (#5922)



* bug in MultiAdapter for Inpainting

* adapter_input is a list for MultiAdapter

---------
Co-authored-by: default avatarandres <andres@hax.ai>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 6031ecbd
......@@ -1470,7 +1470,15 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
height, width = self._default_height_width(height, width, adapter_image)
device = self._execution_device
adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
if isinstance(adapter, MultiAdapter):
adapter_input = []
for one_image in adapter_image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(adapter_image, height, width)
adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
original_size = original_size or (height, width)
target_size = target_size or (height, width)
......@@ -1643,7 +1651,11 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Prepare added time ids & embeddings & adapter features
adapter_input = adapter_input.type(latents.dtype)
if isinstance(adapter, MultiAdapter):
adapter_state = adapter(adapter_input, adapter_conditioning_scale)
for k, v in enumerate(adapter_state):
adapter_state[k] = v
else:
adapter_state = adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment