Unverified Commit d184291c authored by Junsheng121's avatar Junsheng121 Committed by GitHub
Browse files

null-text-inversion-pipeline-implementation (#6329)



* null-text-inversion-implementation

* edited

* edited

* edited

* edited

* edited

* edit

* makestyle

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 0a0bb526
......@@ -54,6 +54,8 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
......@@ -3102,3 +3104,42 @@ output_image = PIL.Image.fromarray(output)
output_image.save("./output.png")
```
### Null-Text Inversion pipeline
This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline.
* Reference paper
```@article{hertz2022prompt,
title={Prompt-to-prompt image editing with cross attention control},
author={Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
booktitle={arXiv preprint arXiv:2208.01626},
year={2022}
```}
```py
from diffusers.schedulers import DDIMScheduler
from examples.community.pipeline_null_text_inversion import NullTextPipeline
import torch
# Load the pipeline
device = "cuda"
# Provide invert_prompt and the image for null-text optimization.
invert_prompt = "A lying cat"
input_image = "siamese.jpg"
steps = 50
# Provide prompt used for generation. Same if reconstruction
prompt = "A lying cat"
# or different if editing.
prompt = "A lying dog"
#Float32 is essential to a well optimization
model_path = "runwayml/stable-diffusion-v1-5"
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear")
pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, torch_dtype=torch.float32).to(device)
#Saves the inverted_latent to save time
inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps)
pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg")
```
\ No newline at end of file
import inspect
import os
import numpy as np
import torch
import torch.nn.functional as nnf
from PIL import Image
from torch.optim.adam import Adam
from tqdm import tqdm
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
def retrieve_timesteps(
scheduler,
num_inference_steps=None,
device=None,
timesteps=None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class NullTextPipeline(StableDiffusionPipeline):
def get_noise_pred(self, latents, t, context):
latents_input = torch.cat([latents] * 2)
guidance_scale = 7.5
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
latents = self.prev_step(noise_pred, t, latents)
return latents
def get_noise_pred_single(self, latents, t, context):
noise_pred = self.unet(latents, t, encoder_hidden_states=context)["sample"]
return noise_pred
@torch.no_grad()
def image2latent(self, image_path):
image = Image.open(image_path).convert("RGB")
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
latents = self.vae.encode(image)["latent_dist"].mean
latents = latents * 0.18215
return latents
@torch.no_grad()
def latent2image(self, latents):
latents = 1 / 0.18215 * latents.detach()
image = self.vae.decode(latents)["sample"].detach()
image = self.processor.postprocess(image, output_type="pil")[0]
return image
def prev_step(self, model_output, timestep, sample):
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = (
self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
return prev_sample
def next_step(self, model_output, timestep, sample):
timestep, next_timestep = (
min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999),
timestep,
)
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
beta_prod_t = 1 - alpha_prod_t
next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
return next_sample
def null_optimization(self, latents, context, num_inner_steps, epsilon):
uncond_embeddings, cond_embeddings = context.chunk(2)
uncond_embeddings_list = []
latent_cur = latents[-1]
bar = tqdm(total=num_inner_steps * self.num_inference_steps)
for i in range(self.num_inference_steps):
uncond_embeddings = uncond_embeddings.clone().detach()
uncond_embeddings.requires_grad = True
optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))
latent_prev = latents[len(latents) - i - 2]
t = self.scheduler.timesteps[i]
with torch.no_grad():
noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
for j in range(num_inner_steps):
noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond)
latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
loss = nnf.mse_loss(latents_prev_rec, latent_prev)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_item = loss.item()
bar.update()
if loss_item < epsilon + i * 2e-5:
break
for j in range(j + 1, num_inner_steps):
bar.update()
uncond_embeddings_list.append(uncond_embeddings[:1].detach())
with torch.no_grad():
context = torch.cat([uncond_embeddings, cond_embeddings])
latent_cur = self.get_noise_pred(latent_cur, t, context)
bar.close()
return uncond_embeddings_list
@torch.no_grad()
def ddim_inversion_loop(self, latent, context):
self.scheduler.set_timesteps(self.num_inference_steps)
_, cond_embeddings = context.chunk(2)
all_latent = [latent]
latent = latent.clone().detach()
with torch.no_grad():
for i in range(0, self.num_inference_steps):
t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]
noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"]
latent = self.next_step(noise_pred, t, latent)
all_latent.append(latent)
return all_latent
def get_context(self, prompt):
uncond_input = self.tokenizer(
[""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
text_input = self.tokenizer(
[prompt],
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
context = torch.cat([uncond_embeddings, text_embeddings])
return context
def invert(
self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50
):
self.num_inference_steps = num_inference_steps
context = self.get_context(prompt)
latent = self.image2latent(image_path)
ddim_latents = self.ddim_inversion_loop(latent, context)
if os.path.exists(image_path + ".pt"):
uncond_embeddings = torch.load(image_path + ".pt")
else:
uncond_embeddings = self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon)
uncond_embeddings = torch.stack(uncond_embeddings, 0)
torch.save(uncond_embeddings, image_path + ".pt")
return ddim_latents[-1], uncond_embeddings
@torch.no_grad()
def __call__(
self,
prompt,
uncond_embeddings,
inverted_latent,
num_inference_steps: int = 50,
timesteps=None,
guidance_scale=7.5,
negative_prompt=None,
num_images_per_prompt=1,
generator=None,
latents=None,
prompt_embeds=None,
negative_prompt_embeds=None,
output_type="pil",
):
self._guidance_scale = guidance_scale
# 0. Default height and width to unet
height = self.unet.config.sample_size * self.vae_scale_factor
width = self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hook
callback_steps = None
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameter
device = self._execution_device
# 3. Encode input prompt
prompt_embeds, _ = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
latents = inverted_latent
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])["sample"]
noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)["sample"]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
progress_bar.update()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
else:
image = latents
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=[True] * image.shape[0]
)
# Offload all models
self.maybe_free_model_hooks()
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment