"src/vscode:/vscode.git/clone" did not exist on "9c03a7da43a6fecfc5d1db978f3fd7d8f3f85ae4"
Unverified Commit 01c056f0 authored by Takuma Mori, committed by GitHub
Browse files

Support ControlNet v1.1 shuffle properly (#3340)



* add inferring_controlnet_cond_batch

* Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

* set guess_mode to True
whenever global_pool_conditions is True
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* nit

* add integration test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e0b56d2b
......@@ -558,7 +558,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
-if guess_mode:
+if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
......
......@@ -930,6 +930,13 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
global_pool_conditions = (
self.controlnet.config.global_pool_conditions
if isinstance(self.controlnet, ControlNetModel)
else self.controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
......
......@@ -623,6 +623,37 @@ class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
assert np.abs(expected_image - image).max() < 1e-1
def test_v11_shuffle_global_pool_conditions(self):
    """Run the ControlNet v1.1 shuffle checkpoint through the pipeline and check the output.

    The pipeline is called without an explicit ``guess_mode`` argument. The
    generated image is checked for its shape, and the bottom-right 3x3 patch
    of its last channel is compared against stored reference values.
    """
    shuffle_controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        safety_checker=None,
        controlnet=shuffle_controlnet,
    )
    pipe.enable_model_cpu_offload()
    pipe.set_progress_bar_config(disable=None)

    # Explicitly seeded CPU generator, passed to the pipeline call below.
    generator = torch.Generator(device="cpu").manual_seed(0)
    control_image = load_image(
        "https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"
    )

    generated = pipe(
        "New York",
        control_image,
        generator=generator,
        output_type="np",
        num_inference_steps=3,
        guidance_scale=7.0,
    ).images[0]

    assert generated.shape == (512, 640, 3)

    # Last three rows and columns of the final channel, flattened to 9 values.
    reference = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
    corner = generated[-3:, -3:, -1].flatten()
    assert np.abs(corner - reference).max() < 1e-2
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment