import copy import os import cv2 import numpy as np import gradio as gr from copy import deepcopy from einops import rearrange from types import SimpleNamespace import datetime import PIL from PIL import Image from PIL.ImageOps import exif_transpose import torch import torch.nn.functional as F from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler from drag_pipeline import DragPipeline from torchvision.utils import save_image from pytorch_lightning import seed_everything from .drag_utils import drag_diffusion_update from .lora_utils import train_lora from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d # -------------- general UI functionality -------------- def clear_all(length=480): return gr.Image.update(value=None, height=length, width=length), \ gr.Image.update(value=None, height=length, width=length), \ gr.Image.update(value=None, height=length, width=length), \ [], None, None def clear_all_gen(length=480): return gr.Image.update(value=None, height=length, width=length), \ gr.Image.update(value=None, height=length, width=length), \ gr.Image.update(value=None, height=length, width=length), \ [], None, None, None def mask_image(image, mask, color=[255,0,0], alpha=0.5): """ Overlay mask on image for visualization purpose. Args: image (H, W, 3) or (H, W): input image mask (H, W): mask to be overlaid color: the color of overlaid mask alpha: the transparency of the mask """ out = deepcopy(image) img = deepcopy(image) img[mask == 1] = color out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) return out def store_img(img, length=512): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. height,width,_ = image.shape image = Image.fromarray(image) image = exif_transpose(image) image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR) mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST) image = np.array(image) if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() # when new image is uploaded, `selected_points` should be empty return image, [], masked_img, mask # once user upload an image, the original image is stored in `original_image` # the same image is displayed in `input_image` for point clicking purpose def store_img_gen(img): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. image = Image.fromarray(image) image = exif_transpose(image) image = np.array(image) if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() # when new image is uploaded, `selected_points` should be empty return image, [], masked_img, mask # user click the image to get points, and show the points on the image def get_points(img, sel_pix, evt: gr.SelectData): # collect the selected point sel_pix.append(evt.index) # draw points points = [] for idx, point in enumerate(sel_pix): if idx % 2 == 0: # draw a red circle at the handle point cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) else: # draw a blue circle at the handle point cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) points.append(tuple(point)) # draw an arrow from handle point to target point if len(points) == 2: cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) points = [] return img if isinstance(img, np.ndarray) else np.array(img) # clear all handle/target points def undo_points(original_image, mask): if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = original_image.copy() return masked_img, [] # ------------------------------------------------------ # ----------- dragging user-input image utils ----------- def train_lora_interface(original_image, prompt, model_path, vae_path, lora_path, lora_step, lora_lr, lora_batch_size, lora_rank, progress=gr.Progress()): train_lora( original_image, prompt, model_path, vae_path, lora_path, lora_step, lora_lr, lora_batch_size, lora_rank, progress) return "Training LoRA Done!" def preprocess_image(image, device): image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] image = rearrange(image, "h w c -> 1 c h w") image = image.to(device) return image def run_drag(source_image, image_with_clicks, mask, prompt, points, inversion_strength, end_step, lam, latent_lr, n_pix_step, model_path, vae_path, lora_path, start_step, start_layer, save_dir="./results" ): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1) model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) # call this function to override unet forward function, # so that intermediate features are returned after forward model.modify_unet_forward() # print(model) # set vae if vae_path != "default": model.vae = AutoencoderKL.from_pretrained( vae_path ).to(model.vae.device, model.vae.dtype) # initialize parameters seed = 42 # random seed used by a lot of people for unknown reason seed_everything(seed) args = SimpleNamespace() args.prompt = prompt args.points = points args.n_inference_step = 50 args.n_actual_inference_step = round(inversion_strength * args.n_inference_step) args.guidance_scale = 1.0 args.unet_feature_idx = [3] args.r_m = 1 args.r_p = 3 args.lam = lam args.end_step = end_step args.lr = latent_lr args.n_pix_step = n_pix_step full_h, full_w = source_image.shape[:2] args.sup_res_h = int(0.5*full_h) args.sup_res_w = int(0.5*full_w) print(args) source_image = preprocess_image(source_image, device) image_with_clicks = preprocess_image(image_with_clicks, device) # set lora if lora_path == "": print("applying default parameters") model.unet.set_default_attn_processor() else: print("applying lora: " + lora_path) model.unet.load_attn_procs(lora_path) # invert the source image # the latent code resolution is too small, only 64*64 invert_code = model.invert(source_image, prompt, guidance_scale=args.guidance_scale, num_inference_steps=args.n_inference_step, num_actual_inference_steps=args.n_actual_inference_step) mask = torch.from_numpy(mask).float() / 255. mask[mask > 0.0] = 1.0 mask = rearrange(mask, "h w -> 1 1 h w").cuda() mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") handle_points = [] target_points = [] # here, the point is in x,y coordinate for idx, point in enumerate(points): cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) cur_point = torch.round(cur_point) if idx % 2 == 0: handle_points.append(cur_point) else: target_points.append(cur_point) print('handle points:', handle_points) print('target points:', target_points) init_code = invert_code init_code_orig = deepcopy(init_code) model.scheduler.set_timesteps(args.n_inference_step) t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] # update according to the given supervision updated_init_code, h_feature, h_features = drag_diffusion_update(model, init_code, t, handle_points, target_points, mask, args) n_move = len(h_features) gen_img_list = [] gen_image = model( prompt=args.prompt, h_feature=h_feature, end_step=args.end_step, batch_size=2, latents=torch.cat([init_code_orig, updated_init_code], dim=0), # latents=torch.cat([updated_init_code, updated_init_code], dim=0), guidance_scale=args.guidance_scale, num_inference_steps=args.n_inference_step, num_actual_inference_steps=args.n_actual_inference_step )[1].unsqueeze(dim=0) # resize gen_image into the size of source_image # we do this because shape of gen_image will be rounded to multipliers of 8 gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') copy_gen = copy.deepcopy(gen_image) gen_img_list.append(copy_gen) # save the original image, user editing instructions, synthesized image save_result = torch.cat([ source_image * 0.5 + 0.5, torch.ones((1, 3, full_h, 25)).cuda(), image_with_clicks * 0.5 + 0.5, torch.ones((1, 3, full_h, 25)).cuda(), gen_image[0:1] ], dim=-1) if not os.path.isdir(save_dir): os.mkdir(save_dir) save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") save_image(gen_image, os.path.join(save_dir, save_prefix + '.png')) # out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] out_image = (out_image * 255).astype(np.uint8) return out_image