# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
from PIL import Image
import os
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.17.0")
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
if tokenizer_max_length is not None:
max_length = tokenizer_max_length
else:
max_length = tokenizer.model_max_length
text_inputs = tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
return text_inputs
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
text_input_ids = input_ids.to(text_encoder.device)
if text_encoder_use_attention_mask:
attention_mask = attention_mask.to(text_encoder.device)
else:
attention_mask = None
prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
return prompt_embeds
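# Example (a minimal sketch; `tokenizer`, `text_encoder`, and the prompt below are
# illustrative placeholders, not objects defined in this file):
#   text_inputs = tokenize_prompt(tokenizer, "a photo of a dog")
#   prompt_embeds = encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask)
#   # prompt_embeds has shape [batch_size, max_length, hidden_size] and conditions the UNet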
# image: input image (numpy array), not yet pre-processed
# prompt: the user-provided text prompt
# model_path: path of the base diffusion model
# vae_path: path of the VAE, or "default" to use the VAE bundled with the model
# save_lora_path: directory in which to save the trained LoRA weights
# lora_step: number of LoRA training steps
# lora_lr: learning rate for LoRA training
# lora_batch_size: batch size used during LoRA training
# lora_rank: rank of the LoRA layers
# progress: progress reporter (e.g. gr.Progress()) used to display training progress
# save_interval: how often (in steps) to save intermediate LoRA checkpoints; -1 disables them
def train_lora(image,
prompt,
model_path,
vae_path,
save_lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress,
# lora_batch_size=1,
save_interval=-1,
):
# initialize accelerator
accelerator = Accelerator(
gradient_accumulation_steps=1,
mixed_precision='fp16'
)
set_seed(0)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_path,
subfolder="tokenizer",
revision=None,
use_fast=False,
)
# initialize the model
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
text_encoder = text_encoder_cls.from_pretrained(
model_path, subfolder="text_encoder", revision=None
)
if vae_path == "default":
vae = AutoencoderKL.from_pretrained(
model_path, subfolder="vae", revision=None
)
else:
vae = AutoencoderKL.from_pretrained(vae_path)
unet = UNet2DConditionModel.from_pretrained(
model_path, subfolder="unet", revision=None
)
# set device and dtype
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
unet.to(device, dtype=torch.float16)
vae.to(device, dtype=torch.float16)
text_encoder.to(device, dtype=torch.float16)
lora_rank_list = [4,4,4,4, 8,8,8,8, 16,16,16,16, 16,16,16,16,16,16, 8,8,8,8,8,8, 4,4,4,4,4,4, 32,32]  # down blocks: 4+4+4, up blocks: 6+6+6, mid block: 1+1
lora_rank_inx = 0
# initialize UNet LoRA
unet_lora_attn_procs = {}
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
else:
raise NotImplementedError("name must start with 'up_blocks', 'mid_block', or 'down_blocks'")
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
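# NOTE: the per-processor rank below is taken from lora_rank_list (doubled),
# so it overrides the lora_rank argument passed to train_lora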
lora_rank = lora_rank_list[lora_rank_inx] * 2
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
)
lora_rank_inx = lora_rank_inx + 1
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
# Optimizer creation
params_to_optimize = (unet_lora_layers.parameters())
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=lora_lr,
betas=(0.9, 0.999),
weight_decay=1e-2,
eps=1e-08,
)
lr_scheduler = get_scheduler(
"constant",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=lora_step,
num_cycles=1,
power=1.0,
)
# prepare accelerator
unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
optimizer = accelerator.prepare_optimizer(optimizer)
lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
# initialize text embeddings
with torch.no_grad():
text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
text_embedding = encode_prompt(
text_encoder,
text_inputs.input_ids,
text_inputs.attention_mask,
text_encoder_use_attention_mask=False
)
text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)
# initialize latent distribution
image_transforms = transforms.Compose(
[
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomCrop(512),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for step in progress.tqdm(range(lora_step), desc="training LoRA"):
unet.train()
image_batch = []
for _ in range(lora_batch_size):
image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
image_transformed = image_transformed.unsqueeze(dim=0)
image_batch.append(image_transformed)
# stack the repeated copies of the transformed image into one batch to enable multi-batch training
image_batch = torch.cat(image_batch, dim=0)
latents_dist = vae.encode(image_batch).latent_dist
model_input = latents_dist.sample() * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Predict the noise residual
model_pred = unet(noisy_model_input, timesteps, text_embedding).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if save_interval > 0 and (step + 1) % save_interval == 0:
save_lora_path_intermediate = os.path.join(save_lora_path, str(step+1))
if not os.path.isdir(save_lora_path_intermediate):
os.mkdir(save_lora_path_intermediate)
# unet = unet.to(torch.float32)
# unwrap_model is used to remove the special modules added for distributed training,
# so there is no need to call unwrap_model here
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path_intermediate,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
)
# unet = unet.to(torch.float16)
# save the trained lora
# unet = unet.to(torch.float32)
# unwrap_model is used to remove the special modules added for distributed training,
# so there is no need to call unwrap_model here
# unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
LoraLoaderMixin.save_lora_weights(
save_directory=save_lora_path,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
)
return
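# Example call (a minimal sketch; the paths, prompt, and hyper-parameters below are
# illustrative placeholders, and `progress` is assumed to provide a tqdm-like interface
# such as gradio's gr.Progress()):
#   train_lora(
#       image=np.array(Image.open("input.png").convert("RGB")),
#       prompt="a photo of a dog",
#       model_path="runwayml/stable-diffusion-v1-5",
#       vae_path="default",
#       save_lora_path="./lora_ckpt",
#       lora_step=80,
#       lora_lr=5e-4,
#       lora_batch_size=1,
#       lora_rank=16,
#       progress=gr.Progress(),
#   )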
import copy
import os
import cv2
import numpy as np
import gradio as gr
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace
import datetime
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from drag_pipeline import DragPipeline
from torchvision.utils import save_image
from pytorch_lightning import seed_everything
from .drag_utils import drag_diffusion_update
from .lora_utils import train_lora
from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d
# -------------- general UI functionality --------------
def clear_all(length=480):
return gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
[], None, None
def clear_all_gen(length=480):
return gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
gr.Image.update(value=None, height=length, width=length), \
[], None, None, None
def mask_image(image,
mask,
color=[255,0,0],
alpha=0.5):
""" Overlay mask on image for visualization purpose.
Args:
image (H, W, 3) or (H, W): input image
mask (H, W): mask to be overlaid
color: the color of overlaid mask
alpha: the transparency of the mask
"""
out = deepcopy(image)
img = deepcopy(image)
img[mask == 1] = color
out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
return out
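# Example (a minimal sketch; `image` is a placeholder H x W x 3 uint8 array and `mask`
# is a binary H x W array with 1 inside the region to highlight):
#   vis = mask_image(image, mask, color=[255, 0, 0], alpha=0.5)  # red overlay on the masked region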
def store_img(img, length=512):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
height,width,_ = image.shape
image = Image.fromarray(image)
image = exif_transpose(image)
image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR)
mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when a new image is uploaded, `selected_points` should be reset to empty
return image, [], masked_img, mask
# once the user uploads an image, the original image is stored in `original_image`;
# the same image is displayed in `input_image` for point-clicking purposes
def store_img_gen(img):
image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
image = Image.fromarray(image)
image = exif_transpose(image)
image = np.array(image)
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = image.copy()
# when a new image is uploaded, `selected_points` should be reset to empty
return image, [], masked_img, mask
# the user clicks on the image to select points; the points are then drawn on the image
def get_points(img,
sel_pix,
evt: gr.SelectData):
# collect the selected point
sel_pix.append(evt.index)
# draw points
points = []
for idx, point in enumerate(sel_pix):
if idx % 2 == 0:
# draw a red circle at the handle point
cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
else:
# draw a blue circle at the target point
cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
points.append(tuple(point))
# draw an arrow from handle point to target point
if len(points) == 2:
cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
points = []
return img if isinstance(img, np.ndarray) else np.array(img)
# clear all handle/target points
def undo_points(original_image,
mask):
if mask.sum() > 0:
mask = np.uint8(mask > 0)
masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
else:
masked_img = original_image.copy()
return masked_img, []
# ------------------------------------------------------
# ----------- dragging user-input image utils -----------
def train_lora_interface(original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress=gr.Progress()):
train_lora(
original_image,
prompt,
model_path,
vae_path,
lora_path,
lora_step,
lora_lr,
lora_batch_size,
lora_rank,
progress)
return "Training LoRA Done!"
def preprocess_image(image,
device):
image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
image = rearrange(image, "h w c -> 1 c h w")
image = image.to(device)
return image
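# Example (a minimal sketch; `src` is a placeholder H x W x 3 uint8 array):
#   src = np.zeros((512, 512, 3), dtype=np.uint8)
#   x = preprocess_image(src, torch.device("cpu"))  # shape [1, 3, 512, 512], values in [-1, 1]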
def run_drag(source_image,
image_with_clicks,
mask,
prompt,
points,
inversion_strength,
end_step,
lam,
latent_lr,
n_pix_step,
model_path,
vae_path,
lora_path,
start_step,
start_layer,
save_dir="./results"
):
# initialize model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
beta_schedule="scaled_linear", clip_sample=False,
set_alpha_to_one=False, steps_offset=1)
model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
# call this function to override the UNet forward function,
# so that intermediate features are returned from the forward pass
model.modify_unet_forward()
# print(model)
# set vae
if vae_path != "default":
model.vae = AutoencoderKL.from_pretrained(
vae_path
).to(model.vae.device, model.vae.dtype)
# initialize parameters
seed = 42  # a commonly used fixed random seed
seed_everything(seed)
args = SimpleNamespace()
args.prompt = prompt
args.points = points
args.n_inference_step = 50
args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
args.guidance_scale = 1.0
args.unet_feature_idx = [3]
args.r_m = 1
args.r_p = 3
args.lam = lam
args.end_step = end_step
args.lr = latent_lr
args.n_pix_step = n_pix_step
full_h, full_w = source_image.shape[:2]
args.sup_res_h = int(0.5*full_h)
args.sup_res_w = int(0.5*full_w)
print(args)
source_image = preprocess_image(source_image, device)
image_with_clicks = preprocess_image(image_with_clicks, device)
# set lora
if lora_path == "":
print("applying default parameters")
model.unet.set_default_attn_processor()
else:
print("applying lora: " + lora_path)
model.unet.load_attn_procs(lora_path)
# invert the source image
# the latent code resolution is too small, only 64*64
invert_code = model.invert(source_image,
prompt,
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step)
mask = torch.from_numpy(mask).float() / 255.
mask[mask > 0.0] = 1.0
mask = rearrange(mask, "h w -> 1 1 h w").cuda()
mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
handle_points = []
target_points = []
# here, each point is given in (x, y) coordinates
for idx, point in enumerate(points):
cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
cur_point = torch.round(cur_point)
if idx % 2 == 0:
handle_points.append(cur_point)
else:
target_points.append(cur_point)
print('handle points:', handle_points)
print('target points:', target_points)
init_code = invert_code
init_code_orig = deepcopy(init_code)
model.scheduler.set_timesteps(args.n_inference_step)
t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
# feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
# update according to the given supervision
updated_init_code, h_feature, h_features = drag_diffusion_update(model, init_code, t,
handle_points, target_points, mask, args)
n_move = len(h_features)
gen_img_list = []
gen_image = model(
prompt=args.prompt,
h_feature=h_feature,
end_step=args.end_step,
batch_size=2,
latents=torch.cat([init_code_orig, updated_init_code], dim=0),
# latents=torch.cat([updated_init_code, updated_init_code], dim=0),
guidance_scale=args.guidance_scale,
num_inference_steps=args.n_inference_step,
num_actual_inference_steps=args.n_actual_inference_step
)[1].unsqueeze(dim=0)
# resize gen_image to the size of source_image
# we do this because the shape of gen_image is rounded to multiples of 8
gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
copy_gen = copy.deepcopy(gen_image)
gen_img_list.append(copy_gen)
# save the original image, the user's editing instructions, and the synthesized image side by side
save_result = torch.cat([
source_image * 0.5 + 0.5,
torch.ones((1, 3, full_h, 25)).cuda(),
image_with_clicks * 0.5 + 0.5,
torch.ones((1, 3, full_h, 25)).cuda(),
gen_image[0:1]
], dim=-1)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
save_image(gen_image, os.path.join(save_dir, save_prefix + '.png'))
#
out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
out_image = (out_image * 255).astype(np.uint8)
return out_image
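# Example call (a minimal sketch; every argument below is an illustrative placeholder, and
# the point list alternates handle/target points in (x, y) pixel coordinates):
#   edited = run_drag(
#       source_image=source_img_np,          # H x W x 3 uint8 array
#       image_with_clicks=clicked_img_np,
#       mask=mask_np,                        # H x W mask, 255 inside the editable region
#       prompt="a photo of a dog",
#       points=[[120, 200], [180, 200]],     # one handle point followed by one target point
#       inversion_strength=0.7,
#       end_step=10,
#       lam=0.1,
#       latent_lr=0.01,
#       n_pix_step=80,
#       model_path="runwayml/stable-diffusion-v1-5",
#       vae_path="default",
#       lora_path="./lora_ckpt",
#       start_step=0,
#       start_layer=10,
#   )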