Unverified commit 1ae15fa6, authored by takuoko, committed by GitHub

[Enhance] Update reference (#3723)



* update reference pipeline

* update reference pipeline

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 027a365a
 # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import numpy as np
 import PIL.Image
 import torch
@@ -97,7 +98,14 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
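Since `np.ndarray` is now an accepted conditioning type, a raw array can be passed straight through as the ControlNet condition. A minimal usage sketch; the model IDs, the community-pipeline name, the image URL, and the Canny preprocessing are illustrative assumptions, not part of this diff:

```python
import cv2
import numpy as np
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_reference",
)

ref = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

# Canny edges as an H x W x 3 uint8 NumPy array: now a valid `image` type.
canny = cv2.Canny(np.array(ref), 100, 200)
canny = np.stack([canny] * 3, axis=-1)

result = pipe(prompt="a portrait painting", image=canny, ref_image=ref).images[0]
```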
@@ -130,8 +138,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass
                 `prompt_embeds` instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
                 the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
@@ -223,15 +231,12 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
+        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
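Two related changes here: the removed `_default_height_width` call is replaced by shape inference later in `__call__` (see the `height, width = image.shape[-2:]` hunks below), and the new assertion fails fast when both reference mechanisms are switched off. Continuing the sketch above:

```python
# With both mechanisms disabled the pipeline has no reference behavior left,
# so it now raises immediately instead of silently running a plain generation:
try:
    pipe(prompt="a cat", image=canny, ref_image=ref,
         reference_attn=False, reference_adain=False)
except AssertionError as err:
    print(err)  # `reference_attn` or `reference_adain` must be True.
```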
@@ -266,6 +271,9 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
             guess_mode = guess_mode or global_pool_conditions

         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -274,6 +282,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )

         # 4. Prepare image
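Forwarding `lora_scale` means the `scale` entry in `cross_attention_kwargs` now attenuates text-encoder LoRA layers as well as UNet ones. A hedged sketch, continuing from above and assuming the pipeline class inherits diffusers' LoRA loading mixin; the checkpoint path is hypothetical:

```python
# Hypothetical LoRA checkpoint that patches both the UNet and the text encoder.
pipe.load_lora_weights("path/to/lora_weights.safetensors")

# "scale" is read out of cross_attention_kwargs and handed to _encode_prompt
# as lora_scale, so 0.5 now damps both halves of the LoRA consistently.
out = pipe(
    prompt="a cat in the style of the LoRA",
    image=canny,
    ref_image=ref,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```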
@@ -289,6 +298,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []

@@ -308,6 +318,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
                 images.append(image_)

             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
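`prepare_image` returns the condition as a batched NCHW tensor, so the trailing two dimensions are exactly the spatial size; in the multi-ControlNet branch every condition shares one resolution, hence reading it from `image[0]`. A self-contained illustration of the convention:

```python
import torch

# prepare_image yields an NCHW tensor, e.g. a batch of two 512x768 conditions.
image = torch.randn(2, 3, 512, 768)
height, width = image.shape[-2:]
print(height, width)  # 512 768
```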
@@ -720,14 +731,15 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
                 # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
-                    controlnet_latent_model_input = latents
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
-                    controlnet_latent_model_input = latent_model_input
+                    control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds

                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    controlnet_latent_model_input,
+                    control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
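Beyond the rename, the functional fix here is the added `scale_model_input` call: in guess mode the ControlNet previously saw raw latents, while sigma-based schedulers expect model inputs scaled by `1 / sqrt(sigma**2 + 1)`. A minimal sketch of why that matters, using a real scheduler (the model ID is illustrative):

```python
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]

# Not a no-op for this scheduler: at the first timestep sigma is ~14, so the
# unscaled latents are an order of magnitude larger than what the (Control)Net
# was trained to consume.
scaled = scheduler.scale_model_input(latents, t)
print(latents.abs().mean().item(), scaled.abs().mean().item())
```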
[remaining hunks of this file collapsed; the diff continues in the second modified file, which contains StableDiffusionReferencePipeline]
@@ -9,6 +9,7 @@ from diffusers import StableDiffusionPipeline
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
 from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
@@ -179,6 +180,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         attention_auto_machine_weight: float = 1.0,
         gn_auto_machine_weight: float = 1.0,
         style_fidelity: float = 0.5,
@@ -248,6 +250,11 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf); `guidance_rescale` is defined as `φ` in
+                equation 16 of that paper. Guidance rescale should fix overexposure when using zero
+                terminal SNR.
             attention_auto_machine_weight (`float`):
                 Weight of using reference query for self attention's context.
                 If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
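For context, the imported `rescale_noise_cfg` helper implements equation 16 by rescaling the CFG output so its standard deviation matches the text-conditional prediction, then blending back by `φ`. A paraphrased sketch of that helper (the real definition lives in `pipeline_stable_diffusion.py`):

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample standard deviation over all non-batch dimensions.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the CFG result toward the std of the text branch (fixes overexposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend: phi = 0 returns plain CFG, phi = 1 returns the fully rescaled result.
    return guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
```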
@@ -295,6 +302,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -303,6 +313,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )

         # 4. Preprocess reference image
@@ -748,6 +759,10 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
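End to end, the new argument is simply threaded through `__call__`. A hedged usage sketch for the plain reference pipeline; the model ID, image URL, and community-pipeline name are assumptions based on the class shown in the hunk headers:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="stable_diffusion_reference",
    torch_dtype=torch.float16,
).to("cuda")

ref = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
out = pipe(
    prompt="a portrait, oil painting",
    ref_image=ref,
    guidance_scale=7.5,
    guidance_rescale=0.7,  # phi from the paper; 0.0 (the default) disables rescaling
    reference_attn=True,
    reference_adain=True,
).images[0]
```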