Commit 5e0c53d0 authored by mashun1

catvton_openpose
from argparse import ArgumentParser
def get_args():
parser = ArgumentParser()
# Model
parser.add_argument("--model_root", type=str, help="path to the inpainting model")
parser.add_argument("--vae_subfolder", type=str, default="vae")
# Data & loading settings
parser.add_argument("--train_data_record_path", type=str, help="path to the training data record file")
parser.add_argument("--eval_data_record_path", type=str, help="path to the evaluation data record file")
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=384)
parser.add_argument("--max_grad_norm", default=1.0)
# Training
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
parser.add_argument("--weight_dtype", type=str, default="bf16")
parser.add_argument("--max_steps", type=int, default=60000)
parser.add_argument("--noise_offset", type=float, default=None)
parser.add_argument("--use_ema", action="store_true")
parser.add_argument("--ema_decay", type=float, default=0.999)
parser.add_argument("--extra_condition_key", type=str, default="empty")
## Optimizer parameters
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--beta1", type=float, default=0.9)
parser.add_argument("--beta2", type=float, default=0.999)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--eps", type=float, default=1e-08)
## Logging & saving settings
parser.add_argument("--logging_steps", type=int, default=5)
parser.add_argument("--output_dir", type=str, default="../checkpoints")
parser.add_argument("--checkpoint_dir", type=str)
parser.add_argument("--eval_output_dir", type=str)
parser.add_argument("--global_steps", type=int, default=0)
args = parser.parse_args()
return args
import os
import torch
from cleanfid import fid as FID
from PIL import Image
from torch.utils.data import Dataset
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision import transforms
from tqdm import tqdm
from pathlib import Path
from utils import scan_files_in_dir
from prettytable import PrettyTable
class EvalDataset(Dataset):
def __init__(self, gt_folder, pred_folder, height=1024):
self.gt_folder = gt_folder
self.pred_folder = pred_folder
self.height = height
self.data = self.prepare_data()
self.to_tensor = transforms.ToTensor()
def extract_id_from_filename(self, filename):
# find first number in filename
start_i = None
for i, c in enumerate(filename):
if c.isdigit():
start_i = i
break
if start_i is None:
assert False, f"Cannot find number in filename {filename}"
return filename[start_i:start_i+8]
def prepare_data(self):
gt_files = scan_files_in_dir(self.gt_folder, postfix={'.jpg', '.png'})
gt_dict = {self.extract_id_from_filename(file.name): file for file in gt_files}
pred_files = scan_files_in_dir(self.pred_folder, postfix={'.jpg', '.png'})
tuples = []
for pred_file in pred_files:
pred_id = self.extract_id_from_filename(pred_file.name)
if pred_id not in gt_dict:
print(f"Cannot find gt file for {pred_file}")
else:
tuples.append((gt_dict[pred_id].path, pred_file.path))
return tuples
def resize(self, img):
w, h = img.size
new_w = int(w * self.height / h)
return img.resize((new_w, self.height), Image.LANCZOS)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
gt_path, pred_path = self.data[idx]
gt, pred = self.resize(Image.open(gt_path)), self.resize(Image.open(pred_path))
if gt.height != self.height:
gt = self.resize(gt)
if pred.height != self.height:
pred = self.resize(pred)
gt = self.to_tensor(gt)
pred = self.to_tensor(pred)
return gt, pred
def copy_resize_gt(gt_folder, height):
# new_folder = f"{gt_folder}_{height}"
new_folder = str(Path(gt_folder).resolve().parent / f"image_{height}")
if not os.path.exists(new_folder):
os.makedirs(new_folder, exist_ok=True)
for file in tqdm(os.listdir(gt_folder)):
if os.path.exists(os.path.join(new_folder, file)):
continue
img = Image.open(os.path.join(gt_folder, file))
w, h = img.size
new_w = int(w * height / h)
img = img.resize((new_w, height), Image.LANCZOS)
img.save(os.path.join(new_folder, file))
return new_folder
@torch.no_grad()
def ssim(dataloader):
ssim_score = 0
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda")
for gt, pred in tqdm(dataloader, desc="Calculating SSIM"):
batch_size = gt.size(0)
gt, pred = gt.to("cuda"), pred.to("cuda")
ssim_score += ssim(pred, gt) * batch_size
return ssim_score / len(dataloader.dataset)
@torch.no_grad()
def lpips(dataloader):
lpips_score = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to("cuda")
score = 0
for gt, pred in tqdm(dataloader, desc="Calculating LPIPS"):
batch_size = gt.size(0)
pred = pred.to("cuda")
gt = gt.to("cuda")
# LPIPS needs the images to be in the [-1, 1] range.
gt = (gt * 2) - 1
pred = (pred * 2) - 1
score += lpips_score(gt, pred) * batch_size
return score / len(dataloader.dataset)
def eval(args):
# Check gt_folder has images with target height, resize if not
pred_sample = os.listdir(args.pred_folder)[0]
gt_sample = os.listdir(args.gt_folder)[0]
img = Image.open(os.path.join(args.pred_folder, pred_sample))
gt_img = Image.open(os.path.join(args.gt_folder, gt_sample))
if img.height != gt_img.height:
title = "--"*30 + "Resizing GT Images to height {img.height}" + "--"*30
print(title)
args.gt_folder = copy_resize_gt(args.gt_folder, img.height)
print("-"*len(title))
# Form dataset
dataset = EvalDataset(args.gt_folder, args.pred_folder, img.height)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, drop_last=False
)
# Calculate Metrics
header = ["FID", "KID"]
fid_ = FID.compute_fid(args.gt_folder, args.pred_folder)
kid_ = FID.compute_kid(args.gt_folder, args.pred_folder) * 1000
row = [fid_, kid_]
if args.paired:
header += ["SSIM", "LPIPS"]
ssim_ = ssim(dataloader).item()
lpips_ = lpips(dataloader).item()
row += [ssim_, lpips_]
# Print Results
print("GT Folder : ", args.gt_folder)
print("Pred Folder: ", args.pred_folder)
table = PrettyTable()
table.field_names = header
table.add_row(row)
print(table)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--gt_folder", type=str, required=True)
parser.add_argument("--pred_folder", type=str, required=True)
parser.add_argument("--paired", action="store_true")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_workers", type=int, default=4)
args = parser.parse_args()
eval(args)
# Date: 2024/11
# Author: 马顺 (Ma Shun)
# Organization: sugon
import random
import torch
import PIL
from PIL import Image, ImageEnhance
from typing import Optional, Union, List, Dict
from torchvision import transforms as T
import torchvision.transforms.functional as TF
class DataAugment:
def __init__(self,
brightness: float = 0.5,
contrast: float = 0.3,
saturation: float = 0.5,
hue: float = 0.5,
brightness_p: float = 0.5,
contrast_p: float = 0.5,
saturation_p: float = 0.5,
hue_p: float = 0.5,
shift_p: float = 0.5,
horizontal_flip_p: float = 0.5):
self.color_jitter = T.ColorJitter(brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue)
self.brightness_p = brightness_p
self.contrast_p = contrast_p
self.saturation_p = saturation_p
self.hue_p = hue_p
self.shift_p = shift_p
self.horizontal_flip_p = horizontal_flip_p
def __call__(self, images: Dict, extra_condition_key):
fn_idx, b, c, s, h = T.ColorJitter.get_params(self.color_jitter.brightness, self.color_jitter.contrast, self.color_jitter.saturation,self.color_jitter.hue)
random_hflip = random.random()
random_brightness = random.random()
random_contrast = random.random()
random_saturation = random.random()
random_hue = random.random()
random_shift = random.random()
shift_valx = random.uniform(-0.2, 0.2)
shift_valy = random.uniform(-0.2, 0.2)
for key, image in images.items():
if key in ['person', 'cloth']:
# color jitter for person and cloth; apply to images[key] so the adjustments accumulate
if random_contrast < self.contrast_p:
images[key] = TF.adjust_contrast(images[key], c)
if random_brightness < self.brightness_p:
images[key] = TF.adjust_brightness(images[key], b)
if random_hue < self.hue_p:
images[key] = TF.adjust_hue(images[key], h)
if random_saturation < self.saturation_p:
images[key] = TF.adjust_saturation(images[key], s)
# for all images
if random_hflip < self.horizontal_flip_p:
images[key] = TF.hflip(images[key])
if random_shift < self.shift_p:
# shared random shift for person, mask and the extra condition, so they stay aligned
if key in ['person', 'mask', extra_condition_key]:
images[key] = TF.affine(images[key], angle=0, translate=[shift_valx*images[key].size[-1], shift_valy*images[key].size[-2]], scale=1, shear=0)
return images
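# --- Usage sketch (illustrative only, not part of the training code) ---
# A minimal example of how DataAugment is driven from the dataset side, assuming the
# PIL images are already loaded; it mirrors the call in data/vitonhd.py. The sizes and
# the "openpose" key are placeholders.
def _example_augment():
    from PIL import Image
    aug = DataAugment()
    images = {
        "person": Image.new("RGB", (384, 512)),
        "cloth": Image.new("RGB", (384, 512)),
        "mask": Image.new("L", (384, 512)),
        "openpose": Image.new("RGB", (384, 512)),
    }
    # Color jitter only touches 'person' and 'cloth'; the horizontal flip touches every
    # image; the random shift is shared by 'person', 'mask' and the extra condition so
    # they stay spatially aligned.
    return aug(images, extra_condition_key="openpose")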
from torch.utils.data import Dataset
from diffusers.image_processor import VaeImageProcessor
from PIL import Image, ImageEnhance
from torchvision.transforms import transforms as T
from typing import Optional
import torch
import json
import random
import numpy as np
from pathlib import Path
current_dir = Path(__file__).resolve().parent
import sys
sys.path.insert(0, str(current_dir))
from aug import DataAugment
class VITHONHD(Dataset):
def __init__(self,
data_record_path: str,
height: int,
width: int,
is_train: bool = True,
extra_condition_key: Optional[str] = "empty",
data_nums: Optional[int] = None,
**kwargs):
self.data = []
with open(data_record_path, "r") as f:
for line in f.readlines()[:data_nums]:
line = json.loads(line.strip())
self.data.append(line)
self.height = height
self.width = width
self.is_train = is_train
self.extra_condition_key = extra_condition_key
self.totensor = T.ToTensor()
self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
self.mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
self.aug = DataAugment(**kwargs)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
person, cloth, mask = [Image.open(data[key]) for key in ['person_img_path', 'cloth_img_path', 'mask_img_path']]
tmp_data = {key: img for (key, img) in zip(["person", "cloth", "mask"], [person, cloth, mask])}
if self.extra_condition_key != "empty":
tmp_data.update({self.extra_condition_key: Image.open(data[self.extra_condition_key])})
if self.is_train:
tmp_data = self.aug(tmp_data, self.extra_condition_key)
return_data = {
"person": self.vae_processor.preprocess(tmp_data['person'], self.height, self.width)[0],
"cloth": self.vae_processor.preprocess(tmp_data['cloth'], self.height, self.width)[0],
"mask": self.mask_processor.preprocess(tmp_data['mask'], self.height, self.width)[0],
}
# TODO: openpose; move the remaining preprocessing out of the dataset
if self.extra_condition_key != "empty":
return_data.update({self.extra_condition_key: self.totensor(tmp_data[self.extra_condition_key].convert('L'))})
else:
return_data.update({self.extra_condition_key: torch.zeros_like(return_data['mask'])})
if self.is_train:
return return_data
return_data.update({
"person_ori": np.array(person.resize((self.width, self.height))),
"mask_ori": np.array(mask.resize((self.width, self.height))),
"name": data['person_img_path'].split("/")[-1] # 文件名
})
return return_data
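# --- Usage sketch (illustrative only) ---
# Each line of the data record file is a JSON object with image paths, e.g. (paths are
# placeholders; the keys match those written by prepare_json()):
#   {"person_img_path": "/data/person/00001.jpg", "cloth_img_path": "/data/cloth/00001.jpg",
#    "mask_img_path": "/data/mask/00001.png", "openpose": "/data/openpose/00001.png"}
# A minimal way to wrap the dataset for training, assuming such a record file exists:
def _example_build_loader(record_path="train_data.jsonl"):
    from torch.utils.data import DataLoader
    dataset = VITHONHD(record_path, height=512, width=384, is_train=True, extra_condition_key="openpose")
    return DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)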
import os
import math
import torch
import numpy as np
from train import init_models
from data.vitonhd import VITHONHD
from model.pipeline import CatVTONPipeline
from torch.utils.data import DataLoader
from argparse import ArgumentParser
from tqdm import tqdm
from accelerate import Accelerator
from PIL import ImageFilter, Image
weight_dtype_maps = {
"no": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16
}
def repaint(person, mask, result):
_, h = result.size
kernel_size = h // 50
if kernel_size % 2 == 0:
kernel_size += 1
# mask = mask.filter(ImageFilter.GaussianBlur(kernel_size))
person_np = np.array(person)
result_np = np.array(result)
mask_np = np.array(mask) / 255
repaint_result = person_np * (1 - mask_np) + result_np * mask_np
repaint_result = Image.fromarray(repaint_result.astype(np.uint8))
return repaint_result
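# --- Illustrative note ---
# repaint() keeps the generated pixels only inside the (0-255) try-on mask and restores
# the original person elsewhere: out = person * (1 - mask/255) + result * (mask/255).
# A minimal sketch with dummy images; it assumes a 3-channel mask so the NumPy
# broadcasting against the RGB images works, and the sizes are placeholders.
def _example_repaint():
    person = Image.new("RGB", (384, 512), (128, 128, 128))
    mask = Image.new("RGB", (384, 512), (255, 255, 255))
    result = Image.new("RGB", (384, 512), (0, 0, 0))
    return repaint(person, mask, result)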
def get_args():
parser = ArgumentParser()
parser.add_argument("--model_root", type=str)
parser.add_argument("--data_record_path", type=str)
parser.add_argument("--vae_subfolder", type=str, default="vae")
parser.add_argument("--output_dir", type=str)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=384)
parser.add_argument("--extra_condition_key", type=str)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--weight_dtype", type=str, default="bf16")
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--guidance_scale", type=float, default=0.)
parser.add_argument("--repaint", action="store_true")
parser.add_argument("--data_nums", type=int, default=None)
args = parser.parse_args()
return args
def main():
args = get_args()
accelerator = Accelerator()
device = accelerator.device
noise_scheduler, vae, unet, optimizer_state_dict = init_models(args.model_root, device=device, vae_subfolder=args.vae_subfolder)
del optimizer_state_dict
unet.eval()
vae.eval() # train is better?
vae.to(weight_dtype_maps[args.weight_dtype])
unet.to(weight_dtype_maps[args.weight_dtype])
pipeline = CatVTONPipeline(noise_scheduler, vae, unet)
datasets = VITHONHD(args.data_record_path, 512, 384, is_train=False, extra_condition_key=args.extra_condition_key, data_nums=args.data_nums)
dataloader = DataLoader(datasets, batch_size=args.batch_size, shuffle=False, num_workers=8)
dataloader = accelerator.prepare(dataloader)
progress_bar = tqdm(total=math.ceil(len(dataloader)), iterable=dataloader, disable=not accelerator.is_main_process)
output_dir = os.path.join(args.output_dir, os.path.join(args.model_root.split("/")[-1], f"cfg_{args.guidance_scale}"))
os.makedirs(output_dir, exist_ok=True)
with torch.no_grad():
for batch in progress_bar:
names = batch['name']
sample = pipeline(
image=batch['person'],
condition_image=batch['cloth'],
mask=batch['mask'],
extra_condition=batch[args.extra_condition_key],
guidance_scale=args.guidance_scale
)
for idx, name in enumerate(names):
save_path = os.path.join(output_dir, name.replace(".jpg", '.png'))
person = Image.fromarray(batch['person_ori'][idx].cpu().numpy())
mask = Image.fromarray(batch['mask_ori'][idx].cpu().numpy())
result = sample[idx]
if args.repaint:
result = repaint(person, mask, result)
result.save(save_path)
if __name__ == "__main__":
main()
from torch.nn import functional as F
import torch
class SkipAttnProcessor(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__()
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
return hidden_states
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
**kwargs
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
import inspect
import os
from typing import Union
import PIL
import numpy as np
import torch
import tqdm
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from model.utils import get_trainable_module, init_adapter
from utils import (compute_vae_encodings, numpy_to_pil, prepare_image,
prepare_mask_image, resize_and_crop, resize_and_padding)
import torch.nn.functional as F
class CatVTONPipeline(DiffusionPipeline):
def __init__(
self,
noise_scheduler,
vae,
unet,
):
self.register_modules(
vae=vae,
unet=unet,
noise_scheduler=noise_scheduler
)
# self.vae.device = vae.device
# self.vae.dtype = self.vae.dtype
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.noise_scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.noise_scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, torch.Tensor],
condition_image: Union[PIL.Image.Image, torch.Tensor],
mask: Union[PIL.Image.Image, torch.Tensor],
extra_condition: Union[torch.Tensor, None] = None,
num_inference_steps: int = 50,
guidance_scale: float = 2.5,
height: int = 1024,
width: int = 768,
generator=None,
eta=1.0,
**kwargs
):
concat_dim = -2 # FIXME: y axis concat
# Prepare inputs to Tensor
image = prepare_image(image).to(self.vae.device, dtype=self.vae.dtype)
condition_image = prepare_image(condition_image).to(self.vae.device, dtype=self.vae.dtype)
mask = prepare_mask_image(mask).to(self.vae.device, dtype=self.vae.dtype)
# Mask image
masked_image = image * (mask < 0.5)
# VAE encoding
masked_latent = compute_vae_encodings(masked_image, self.vae)
condition_latent = compute_vae_encodings(condition_image, self.vae)
mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
del image, mask, condition_image
# Concatenate latents
masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
# The extra condition (e.g. openpose) is stacked under the mask along the height axis;
# fall back to zeros when no extra condition is provided.
if extra_condition is not None:
extra_condition = F.interpolate(extra_condition, size=mask_latent.shape[-2:], mode="nearest").to(self.vae.device, dtype=self.vae.dtype)
else:
extra_condition = torch.zeros_like(mask_latent).to(self.vae.device, dtype=self.vae.dtype)
mask_latent_concat = torch.cat([mask_latent, extra_condition], dim=concat_dim)
# Prepare noise
latents = randn_tensor(
masked_latent_concat.shape,
generator=generator,
device=masked_latent_concat.device,
dtype=self.vae.dtype,
)
# Prepare timesteps
self.noise_scheduler.set_timesteps(num_inference_steps, device=self.vae.device)
timesteps = self.noise_scheduler.timesteps
latents = latents * self.noise_scheduler.init_noise_sigma
# Classifier-Free Guidance
if do_classifier_free_guidance := (guidance_scale > 1.0):
masked_latent_concat = torch.cat(
[
torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
masked_latent_concat,
]
)
mask_latent_concat = torch.cat([mask_latent_concat] * 2)
# Denoising loop
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
num_warmup_steps = (len(timesteps) - num_inference_steps * self.noise_scheduler.order)
with tqdm.tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(non_inpainting_latent_model_input, t)
# prepare the input for the inpainting model
inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1)
# predict the noise residual
noise_pred= self.unet(
inpainting_latent_model_input,
t.to(self.vae.device),
encoder_hidden_states=None, # FIXME
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.noise_scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps
and (i + 1) % self.noise_scheduler.order == 0
):
progress_bar.update()
# Decode the final latents
latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents.to(self.vae.device, dtype=self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = numpy_to_pil(image)
return image
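# --- Shape walkthrough (illustrative, assuming 512x384 inputs and the 9-channel SD inpainting UNet) ---
# CatVTON stacks the person and cloth along the latent height axis (concat_dim = -2) and
# keeps the usual inpainting channel layout, so for a 512x384 image (latents 64x48):
#   masked_latent / condition_latent : (B, 4, 64, 48)
#   masked_latent_concat             : (B, 4, 128, 48)   person stacked above cloth
#   mask_latent_concat               : (B, 1, 128, 48)   mask stacked above the extra condition
#   UNet input                       : (B, 4 + 1 + 4, 128, 48)
# Only the top half of the denoised latent (the person region) is decoded at the end.
def _example_latent_layout():
    # Dummy-tensor check of the layout above; no models are needed.
    masked_latent = torch.zeros(1, 4, 64, 48)
    condition_latent = torch.zeros(1, 4, 64, 48)
    mask_latent = torch.zeros(1, 1, 64, 48)
    extra_condition = torch.zeros(1, 1, 64, 48)
    masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=-2)
    mask_latent_concat = torch.cat([mask_latent, extra_condition], dim=-2)
    noisy_latents = torch.zeros_like(masked_latent_concat)
    unet_input = torch.cat([noisy_latents, mask_latent_concat, masked_latent_concat], dim=1)
    assert unet_input.shape == (1, 9, 128, 48)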
import os
import json
import torch
from model.attn_processor import AttnProcessor2_0, SkipAttnProcessor
def init_adapter(unet,
cross_attn_cls=SkipAttnProcessor,
self_attn_cls=None,
cross_attn_dim=None,
**kwargs):
if cross_attn_dim is None:
cross_attn_dim = unet.config.cross_attention_dim
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim # self-attn
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
if self_attn_cls is not None:
attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
else:
# retain the original attn processor
attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
else:
attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
return adapter_modules
def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
try:
unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
unet_configs = json.load(open(os.path.join(unet_folder, "config.json"), "r"))
unet = unet_class(**unet_configs)
unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
except Exception:
unet = None
return text_encoder, vae, tokenizer, unet
def attn_of_unet(unet):
attn_blocks = torch.nn.ModuleList()
for name, param in unet.named_modules():
if "attn1" in name:
attn_blocks.append(param)
return attn_blocks
def get_trainable_module(unet, trainable_module_name):
if trainable_module_name == "unet":
return unet
elif trainable_module_name == "transformer":
trainable_modules = torch.nn.ModuleList()
for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
if hasattr(blocks, "attentions"):
trainable_modules.append(blocks.attentions)
else:
for block in blocks:
if hasattr(block, "attentions"):
trainable_modules.append(block.attentions)
return trainable_modules
elif trainable_module_name == "attention":
attn_blocks = torch.nn.ModuleList()
for name, param in unet.named_modules():
if "attn1" in name:
attn_blocks.append(param)
return attn_blocks
else:
raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
#!/bin/bash
gt_folder="../datasets/test/vitonhd/image_512"
pred_folder="../test_outputs/26000_openpose/cfg_0.0"
python cal_metrics.py \
--gt_folder=${gt_folder} \
--pred_folder=${pred_folder} \
--paired
#!/bin/bash
model_root="/home/vtryon/catvton_v0/checkpoints/32000_openpose_bak"
data_record_path="/home/vtryon/catvton_v0/datasets/test/test_data.jsonl"
vae_subfolder="vae"
output_dir="../test_outputs"
height=512
width=384
extra_condition_key="openpose"
batch_size=16
weight_dtype="fp16"
num_inference_steps=50
guidance_scale=0
data_nums=8
accelerate launch --main_process_port=12321 generate_test_sample.py \
--model_root=${model_root} \
--data_record_path=${data_record_path} \
--vae_subfolder=${vae_subfolder} \
--output_dir=${output_dir} \
--height=${height} \
--width=${width} \
--extra_condition_key=${extra_condition_key} \
--batch_size=${batch_size} \
--weight_dtype=${weight_dtype} \
--num_inference_steps=${num_inference_steps} \
--guidance_scale=${guidance_scale} \
--repaint
#!/bin/bash
# Path to the initial / resume checkpoint model
model_root="/home/vtryon/catvton_v0/pretrained_models/stable-diffusion-inpainting"
train_data_record_path="/home/vtryon/catvton_v0/datasets/train_data.jsonl"
eval_data_record_path="/home/vtryon/catvton_v0/datasets/eval_data.jsonl"
height=512
width=384
vae_subfolder="sd-vae-ft-mse"
weight_dtype="bf16"
max_steps=32000
logging_steps=1000
global_steps=0
batch_size=8
gradient_accumulation_steps=4
extra_condition_key="openpose"
eval_output_dir=${max_steps}_${extra_condition_key}
checkpoint_dir=${max_steps}_${extra_condition_key}
accelerate launch train.py \
--model_root=${model_root} \
--train_data_record_path=${train_data_record_path} \
--eval_data_record_path=${eval_data_record_path} \
--height=${height} \
--width=${width} \
--eval_output_dir=${eval_output_dir} \
--checkpoint_dir=${checkpoint_dir} \
--vae_subfolder=${vae_subfolder} \
--weight_dtype=${weight_dtype} \
--max_steps=${max_steps} \
--logging_steps=${logging_steps} \
--global_steps=${global_steps} \
--batch_size=${batch_size} \
--gradient_accumulation_steps=${gradient_accumulation_steps} \
--extra_condition_key=${extra_condition_key}
import os
import torch
import random
import torch.nn.functional as F
from torch.optim import AdamW
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from tqdm import tqdm
from typing import Optional
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from model.attn_processor import SkipAttnProcessor
from model.utils import init_adapter
from args import get_args
from data.vitonhd import VITHONHD
from utils import prepare_image, prepare_mask_image, compute_vae_encodings, compute_dream_and_update_latents_for_inpaint
from model.pipeline import CatVTONPipeline
def init_models(model_root: str,
weight_dtype: str = "no",
vae_subfolder: str = "vae",
device = "cpu"):
if weight_dtype == "no":
weight_dtype = torch.float32
elif weight_dtype == "fp16":
weight_dtype = torch.float16
elif weight_dtype == "bf16":
weight_dtype = torch.bfloat16
else:
raise NotImplementedError(f"Unknown weight_dtype: {weight_dtype}")
print(f"load vae from {vae_subfolder}")
vae = AutoencoderKL.from_pretrained(model_root, subfolder=vae_subfolder)
unet = UNet2DConditionModel.from_pretrained(model_root, subfolder="unet")
try:
noise_scheduler = DDIMScheduler.from_pretrained(model_root, subfolder="scheduler")
except Exception as e:
noise_scheduler = DDIMScheduler.from_pretrained(model_root, subfolder="noise_scheduler")
init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
vae.to(device)
unet.to(device)
vae.requires_grad_(False)
unet.requires_grad_(False)
for name, param in unet.named_modules():
if "attn1" in name:
param.requires_grad_(True)
unet.train()
# unet.enable_gradient_checkpointing()
optimizer_path = os.path.join(model_root, "optim.pth")
if os.path.exists(optimizer_path):
optimizer_state_dict = torch.load(optimizer_path)
else:
optimizer_state_dict = None
return noise_scheduler, vae, unet, optimizer_state_dict
def train_one_step(batch,
noise_scheduler,
vae,
unet,
device,
extra_condition_key):
person = prepare_image(batch['person'])
cloth = prepare_image(batch['cloth'])
mask = prepare_mask_image(batch['mask'])
masked_person = person * (mask < 0.5)
person_latent = compute_vae_encodings(person, vae) # noised below
masked_person_latent = compute_vae_encodings(masked_person, vae)
if random.random() < 0.15:
# for cfg
cloth_latent = torch.zeros_like(masked_person_latent).to(device).to(masked_person_latent.dtype)
else:
cloth_latent = compute_vae_encodings(cloth, vae)
mask_latent = F.interpolate(mask, size=masked_person_latent.shape[-2:], mode="nearest")
bsz = masked_person_latent.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz, )).to(device).long()
first_input_latent = torch.concat([person_latent, cloth_latent], dim=-2)
noise = torch.randn_like(first_input_latent)
noisy_first_latent = noise_scheduler.add_noise(first_input_latent, noise, timesteps)
masked_latent_concat = torch.cat([masked_person_latent, cloth_latent], dim=-2)
extra_condition = batch.get(extra_condition_key, None)
extra_condition = F.interpolate(extra_condition, size=mask_latent.shape[-2:], mode="nearest")
mask_latent_concat = torch.cat([mask_latent, extra_condition], dim=-2)
inpainting_latent_model_input = torch.cat([noisy_first_latent, mask_latent_concat, masked_latent_concat], dim=1)
noise_pred = unet(
inpainting_latent_model_input,
timesteps,
encoder_hidden_states=None,
return_dict=False
)[0]
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
return loss
def main():
args = get_args()
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.weight_dtype
)
device = accelerator.device
train_dataset = VITHONHD(args.train_data_record_path, args.height, args.width, extra_condition_key=args.extra_condition_key)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
noise_scheduler, vae, unet, optimizer_state_dict = init_models(args.model_root, device=device, vae_subfolder=args.vae_subfolder)
optimizer = AdamW(unet.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps, weight_decay=args.weight_decay)
if optimizer_state_dict:
print("加载优化器状态")
optimizer.load_state_dict(optimizer_state_dict)
if accelerator.is_main_process:
eval_dataset = VITHONHD(args.eval_data_record_path, args.height, args.width, is_train=False, extra_condition_key=args.extra_condition_key)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1, num_workers=args.num_workers)
else:
eval_dataloader = None
(
unet,
optimizer,
train_dataloader
) = accelerator.prepare(
unet,
optimizer,
train_dataloader
)
global_step = args.global_steps
reach_max_steps = False
progress_bar = tqdm(initial=global_step, total=args.max_steps, disable=not accelerator.is_main_process)
progress_bar.set_description("train catvton")
while True:
if reach_max_steps:
print("到达最大训练步数,停止训练")
break
avg_loss = 0.
for batch in train_dataloader:
with accelerator.accumulate(unet):
with accelerator.autocast():
loss = train_one_step(
batch,
noise_scheduler,
vae,
unet,
device,
args.extra_condition_key
)
avg_loss += loss.item()
accelerator.backward(loss)
# TODO: needs attention
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
avg_loss = torch.tensor(avg_loss).to(device)
avg_loss = accelerator.gather(avg_loss).mean().item() / accelerator.gradient_accumulation_steps
progress_bar.update(1)
logs = {"step_loss": avg_loss, "global_steps": global_step}
progress_bar.set_postfix(**logs)
global_step += 1
avg_loss = 0.
# Evaluate and save the model
if global_step % args.logging_steps == 0 or global_step >= args.max_steps:
if accelerator.is_main_process:
unwrap_unet = accelerator.unwrap_model(unet)
unwrap_unet.eval()
pipeline = CatVTONPipeline(noise_scheduler, vae, unwrap_unet)
os.makedirs(f"../eval_outputs/{args.eval_output_dir}/{global_step}", exist_ok=True)
with torch.no_grad():
for idx, batch in enumerate(eval_dataloader):
if args.extra_condition_key:
sample = pipeline(
image=batch['person'],
condition_image=batch['cloth'],
mask=batch['mask'],
extra_condition=batch[args.extra_condition_key]
)[0]
else:
sample = pipeline(
image=batch['person'],
condition_image=batch['cloth'],
mask=batch['mask']
)[0]
sample.save(f"../eval_outputs/{args.eval_output_dir}/{global_step}/{idx}.png")
save_path = os.path.join(args.output_dir, args.checkpoint_dir)
pipeline.save_pretrained(save_path)
torch.save(optimizer.state_dict(), f"{save_path}/optim.pth")
del pipeline
del unwrap_unet
torch.cuda.empty_cache()
if global_step >= args.max_steps:
reach_max_steps = True
break
if __name__ == "__main__":
main()
import os
import math
import PIL
import numpy as np
import torch
from PIL import Image
from accelerate.state import AcceleratorState
from packaging import version
import accelerate
from typing import List, Optional, Tuple, Set
from diffusers import UNet2DConditionModel, SchedulerMixin
from tqdm import tqdm
# Compute DREAM and update latents for diffusion sampling
def compute_dream_and_update_latents_for_inpaint(
unet: UNet2DConditionModel,
noise_scheduler: SchedulerMixin,
timesteps: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
target: torch.Tensor,
encoder_hidden_states: torch.Tensor,
dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
forward step without gradients.
Args:
`unet`: The state unet to use to make a prediction.
`noise_scheduler`: The noise scheduler used to add noise for the given timestep.
`timesteps`: The timesteps for the noise_scheduler to use.
`noise`: A tensor of noise in the shape of noisy_latents.
`noisy_latents`: Previously noise latents from the training loop.
`target`: The ground-truth tensor to predict after eps is removed.
`encoder_hidden_states`: Text embeddings from the text model.
`dream_detail_preservation`: A float value that indicates detail preservation level.
See reference.
Returns:
`tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
"""
alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
pred = None # b, 4, h, w
# with torch.no_grad():
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # predict the noise
noisy_latents_no_condition = noisy_latents[:, :4] # person_masked
_noisy_latents, _target = (None, None)
if noise_scheduler.config.prediction_type == "epsilon":
predicted_noise = pred
delta_noise = (noise - predicted_noise).detach() # 4
delta_noise.mul_(dream_lambda)
_noisy_latents = noisy_latents_no_condition.add(sqrt_one_minus_alphas_cumprod * delta_noise) # 5
_target = target.add(delta_noise)
elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError("DREAM has not been implemented for v-prediction")
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
_noisy_latents = torch.cat([_noisy_latents, noisy_latents[:, 4:]], dim=1)
return _noisy_latents, _target
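# Note: compute_dream_and_update_latents_for_inpaint is imported in train.py but not yet
# wired into train_one_step in this commit. One illustrative way to apply it, assuming the
# 9-channel input layout (noisy latents + mask/extra condition + masked latents) built there:
#   inpainting_latent_model_input, target = compute_dream_and_update_latents_for_inpaint(
#       unet, noise_scheduler, timesteps, noise, inpainting_latent_model_input, noise,
#       encoder_hidden_states=None)
#   noise_pred = unet(inpainting_latent_model_input, timesteps, encoder_hidden_states=None, return_dict=False)[0]
#   loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")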
# Prepare the input for inpainting model.
def prepare_inpainting_input(
noisy_latents: torch.Tensor,
mask_latents: torch.Tensor,
condition_latents: torch.Tensor,
enable_condition_noise: bool = True,
condition_concat_dim: int = -1,
) -> torch.Tensor:
"""
Prepare the input for inpainting model.
Args:
noisy_latents (torch.Tensor): Noisy latents.
mask_latents (torch.Tensor): Mask latents.
condition_latents (torch.Tensor): Condition latents.
enable_condition_noise (bool): Enable condition noise.
Returns:
torch.Tensor: Inpainting input.
"""
if not enable_condition_noise:
condition_latents_ = condition_latents.chunk(2, dim=condition_concat_dim)[-1]
noisy_latents = torch.cat([noisy_latents, condition_latents_], dim=condition_concat_dim)
noisy_latents = torch.cat([noisy_latents, mask_latents, condition_latents], dim=1)
return noisy_latents
# Compute VAE encodings
def compute_vae_encodings(image: torch.Tensor, vae: torch.nn.Module) -> torch.Tensor:
"""
Args:
images (torch.Tensor): image to be encoded
vae (torch.nn.Module): vae model
Returns:
torch.Tensor: latent encoding of the image
"""
# pixel_values = image.to(memory_format=torch.contiguous_format).float()
pixel_values = image.float()
pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
return model_input
# Init Accelerator
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import ProjectConfiguration
def init_accelerator(config):
accelerator_project_config = ProjectConfiguration(
project_dir=config.project_name,
logging_dir=os.path.join(config.project_name, "logs"),
)
accelerator_ddp_config = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
log_with=config.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[accelerator_ddp_config],
gradient_accumulation_steps=config.gradient_accumulation_steps,
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if accelerator.is_main_process:
accelerator.init_trackers(
project_name=config.project_name,
config={
"learning_rate": config.learning_rate,
"train_batch_size": config.train_batch_size,
"image_size": f"{config.width}x{config.height}",
},
)
return accelerator
def init_weight_dtype(weight_dtype):
return {
"no": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}[weight_dtype]
def init_add_item_id(config):
return torch.tensor(
[
config.height,
config.width * 2,
0,
0,
config.height,
config.width * 2,
]
).repeat(config.train_batch_size, 1)
def prepare_eval_data(dataset_root, dataset_name, is_pair=True):
assert dataset_name in ["vitonhd", "dresscode", "farfetch"], "Unknown dataset name {}.".format(dataset_name)
if dataset_name == "vitonhd":
data_root = os.path.join(dataset_root, "VITONHD-1024", "test")
if is_pair:
keys = os.listdir(os.path.join(data_root, "Images"))
cloth_image_paths = [
os.path.join(data_root, "Images", key, key + "-0.jpg") for key in keys
]
person_image_paths = [
os.path.join(data_root, "Images", key, key + "-1.jpg") for key in keys
]
else:
# read ../test_pairs.txt
cloth_image_paths = []
person_image_paths = []
with open(
os.path.join(dataset_root, "VITONHD-1024", "test_pairs.txt"), "r"
) as f:
lines = f.readlines()
for line in lines:
cloth_image, person_image = (
line.replace(".jpg", "").strip().split(" ")
)
cloth_image_paths.append(
os.path.join(
data_root, "Images", cloth_image, cloth_image + "-0.jpg"
)
)
person_image_paths.append(
os.path.join(
data_root, "Images", person_image, person_image + "-1.jpg"
)
)
elif dataset_name == "dresscode":
data_root = os.path.join(dataset_root, "DressCode-1024")
if is_pair:
part = ["lower", "lower", "upper", "upper", "dresses", "dresses"]
ids = ["013581", "051685", "000190", "050072", "020829", "053742"]
cloth_image_paths = [
os.path.join(data_root, "Images", part[i], ids[i], ids[i] + "_1.jpg")
for i in range(len(part))
]
person_image_paths = [
os.path.join(data_root, "Images", part[i], ids[i], ids[i] + "_0.jpg")
for i in range(len(part))
]
else:
raise ValueError("DressCode dataset does not support non-pair evaluation.")
elif dataset_name == "farfetch":
data_root = os.path.join(dataset_root, "FARFETCH-1024")
cloth_image_paths = [
# TryOn
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Blouses/13732751/13732751-2.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Hoodies/14661627/14661627-4.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Vests & Tank Tops/16532697/16532697-4.jpg",
"Images/men/Pants/Loose Fit Pants/14750720/14750720-6.jpg",
# Garment Transfer
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Shirts/10889688/10889688-3.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Shorts/Leather & Faux Leather Shorts/20143338/20143338-1.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Jackets/Blazers/15541224/15541224-2.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/men/Polo Shirts/Polo Shirts/17652415/17652415-0.jpg"
# "Images/men/Jackets/Hooded Jackets/12550261/12550261-1.jpg",
# "Images/men/Shirts/Shirts/15614589/15614589-4.jpg",
# "Images/women/Dresses/Day Dresses/10372515/10372515-3.jpg",
# "Images/women/Dresses/Sundresses/18520992/18520992-4.jpg",
# "Images/women/Skirts/Asymmetric & Draped Skirts/12404908/12404908-2.jpg",
]
person_image_paths = [
# TryOn
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Blouses/13732751/13732751-0.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Hoodies/14661627/14661627-2.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Vests & Tank Tops/16532697/16532697-1.jpg",
"Images/men/Pants/Loose Fit Pants/14750720/14750720-5.jpg",
# Garment Transfer
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Shirts/10889688/10889688-1.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Shorts/Leather & Faux Leather Shorts/20143338/20143338-2.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Jackets/Blazers/15541224/15541224-0.jpg",
"/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/men/Polo Shirts/Polo Shirts/17652415/17652415-4.jpg",
# "Images/men/Jackets/Hooded Jackets/12550261/12550261-3.jpg",
# "Images/men/Shirts/Shirts/15614589/15614589-3.jpg",
# "Images/women/Dresses/Day Dresses/10372515/10372515-0.jpg",
# "Images/women/Dresses/Sundresses/18520992/18520992-1.jpg",
# "Images/women/Skirts/Asymmetric & Draped Skirts/12404908/12404908-1.jpg",
]
cloth_image_paths = [
os.path.join(data_root, path) for path in cloth_image_paths
]
person_image_paths = [
os.path.join(data_root, path) for path in person_image_paths
]
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")
samples = [
{
"folder": os.path.basename(os.path.dirname(cloth_image)),
"cloth": cloth_image,
"person": person_image,
}
for cloth_image, person_image in zip(
cloth_image_paths, person_image_paths
)
]
return samples
def repaint_result(result, person_image, mask_image):
result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
# expand the mask to 3 channels & to 0~1
mask = np.expand_dims(mask, axis=2)
mask = mask / 255.0
# mask for result, ~mask for person
result_ = result * mask + person * (1 - mask)
return Image.fromarray(result_.astype(np.uint8))
def prepare_image(image):
if isinstance(image, torch.Tensor):
# Batch single image
if image.ndim == 3:
image = image.unsqueeze(0)
image = image.to(dtype=torch.float32)
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image
def prepare_mask_image(mask_image):
if isinstance(mask_image, torch.Tensor):
if mask_image.ndim == 2:
# Batch and add channel dim for single mask
mask_image = mask_image.unsqueeze(0).unsqueeze(0)
elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
# Single mask, the 0'th dimension is considered to be
# the existing batch size of 1
mask_image = mask_image.unsqueeze(0)
elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
# Batch of mask, the 0'th dimension is considered to be
# the batching dimension
mask_image = mask_image.unsqueeze(1)
# Binarize mask
mask_image[mask_image < 0.5] = 0
mask_image[mask_image >= 0.5] = 1
else:
# preprocess mask
if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
mask_image = [mask_image]
if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
mask_image = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0
)
mask_image = mask_image.astype(np.float32) / 255.0
elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
mask_image[mask_image < 0.5] = 0
mask_image[mask_image >= 0.5] = 1
mask_image = torch.from_numpy(mask_image)
return mask_image
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def tensor_to_image(tensor: torch.Tensor):
"""
Converts a torch tensor to PIL Image.
"""
assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
assert tensor.dtype == torch.float32, "Input tensor should be float32."
assert (
tensor.min() >= 0 and tensor.max() <= 1
), "Input tensor should be in range [0, 1]."
tensor = tensor.cpu()
tensor = tensor * 255
tensor = tensor.permute(1, 2, 0)
tensor = tensor.numpy().astype(np.uint8)
image = Image.fromarray(tensor)
return image
def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
"""
Concatenates images into a grid of `cols` images per row, separated by `divider` pixels.
"""
widths = [image.size[0] for image in images]
heights = [image.size[1] for image in images]
total_width = cols * max(widths)
total_width += divider * (cols - 1)
# `col` images each row
rows = math.ceil(len(images) / cols)
total_height = max(heights) * rows
# add divider between rows
total_height += divider * (rows - 1)
# all black image
concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))
x_offset = 0
y_offset = 0
for i, image in enumerate(images):
concat_image.paste(image, (x_offset, y_offset))
x_offset += image.size[0] + divider
if (i + 1) % cols == 0:
x_offset = 0
y_offset += image.size[1] + divider
return concat_image
def read_prompt_file(prompt_file: str):
if prompt_file is not None and os.path.isfile(prompt_file):
with open(prompt_file, "r") as sample_prompt_file:
sample_prompts = sample_prompt_file.readlines()
sample_prompts = [sample_prompt.strip() for sample_prompt in sample_prompts]
else:
sample_prompts = []
return sample_prompts
def save_tensors_to_npz(tensors: torch.Tensor, paths: List[str]):
assert len(tensors) == len(paths), "Length of tensors and paths should be the same!"
for tensor, path in zip(tensors, paths):
np.savez_compressed(path, latent=tensor.cpu().numpy())
def deepspeed_zero_init_disabled_context_manager():
"""
returns either a context list that includes one that will disable zero.Init or an empty context list
"""
deepspeed_plugin = (
AcceleratorState().deepspeed_plugin
if accelerate.state.is_initialized()
else None
)
if deepspeed_plugin is None:
return []
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
def is_xformers_available():
try:
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
print(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. "
"See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
return True
except ImportError:
raise ValueError(
"xformers is not available. Make sure it is installed correctly"
)
def resize_and_crop(image, size):
# Crop to size ratio
w, h = image.size
target_w, target_h = size
if w / h < target_w / target_h:
new_w = w
new_h = w * target_h // target_w
else:
new_h = h
new_w = h * target_w // target_h
image = image.crop(
((w - new_w) // 2, (h - new_h) // 2, (w + new_w) // 2, (h + new_h) // 2)
)
# resize
image = image.resize(size, Image.LANCZOS)
return image
def resize_and_padding(image, size):
# Padding to size ratio
w, h = image.size
target_w, target_h = size
if w / h < target_w / target_h:
new_h = target_h
new_w = w * target_h // h
else:
new_w = target_w
new_h = h * target_w // w
image = image.resize((new_w, new_h), Image.LANCZOS)
# padding
padding = Image.new("RGB", size, (255, 255, 255))
padding.paste(image, ((target_w - new_w) // 2, (target_h - new_h) // 2))
return padding
def scan_files_in_dir(directory, postfix: Set[str] = None, progress_bar: tqdm = None) -> list:
file_list = []
progress_bar = tqdm(total=0, desc=f"Scanning", ncols=100) if progress_bar is None else progress_bar
for entry in os.scandir(directory):
if entry.is_file():
if postfix is None or os.path.splitext(entry.path)[1] in postfix:
file_list.append(entry)
progress_bar.total += 1
progress_bar.update(1)
elif entry.is_dir():
file_list += scan_files_in_dir(entry.path, postfix=postfix, progress_bar=progress_bar)
return file_list
if __name__ == "__main__":
...
import json
import os
import shutil
def main(args):
with open(args.data_record_path, "r") as f:
for line in f.readlines():
line = json.loads(line.strip())
for key, path in line.items():
save_path = os.path.join(args.save_root, key)
os.makedirs(save_path, exist_ok=True)
shutil.copy2(path, save_path)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--data_record_path", type=str)
parser.add_argument("--save_root", type=str)
args = parser.parse_args()
main(args)
# Write the data paths into a JSONL file
from pathlib import Path
from glob import glob
import os
import json
import random
def make_mapping(path_list):
return {path.stem: str(path.resolve()) for path in path_list}
def prepare_json(person_image_root,
cloth_image_root,
mask_root,
extra_condition_image_root,
extra_condition_key,
eval_nums: int,
save_root: str):
person_image_root = Path(person_image_root)
cloth_image_root = Path(cloth_image_root)
mask_root = Path(mask_root)
extra_condition_image_root = Path(extra_condition_image_root)
person_image_path_list = [*person_image_root.glob("*.png"), *person_image_root.glob("*.jpg"), *person_image_root.glob("*.jpeg")]
cloth_image_path_list = [*cloth_image_root.glob("*.png"), *cloth_image_root.glob("*.jpg"), *cloth_image_root.glob("*.jpeg")]
mask_path_list = [*mask_root.glob("*.png"), *mask_root.glob("*.jpg"), *mask_root.glob("*.jpeg")]
extra_condition_image_path_list = [*extra_condition_image_root.glob("*.png"), *extra_condition_image_root.glob("*.jpg"), *extra_condition_image_root.glob("*.jpeg")]
person_image_path_mapping = make_mapping(person_image_path_list)
cloth_image_path_mapping = make_mapping(cloth_image_path_list)
mask_path_mapping = make_mapping(mask_path_list)
extra_condition_image_path_mapping = make_mapping(extra_condition_image_path_list)
keys = set(person_image_path_mapping.keys()) & set(cloth_image_path_mapping.keys()) & \
set(mask_path_mapping.keys()) & set(extra_condition_image_path_mapping.keys())
keys = list(keys)
all_index = range(len(keys))
eval_index = set(random.sample(all_index, eval_nums))
train_index = set(all_index) - eval_index
eval_index, train_index = list(eval_index), list(train_index)
with open(os.path.join(save_root, "train_data.jsonl"), "w") as f:
for idx in train_index:
key = keys[idx]
temp = {}
temp['person_img_path'] = person_image_path_mapping[key]
temp['cloth_img_path'] = cloth_image_path_mapping[key]
temp['mask_img_path'] = mask_path_mapping[key]
temp[extra_condition_key] = extra_condition_image_path_mapping[key]
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
with open(os.path.join(save_root, "eval_data.jsonl"), "w") as f:
for idx in eval_index:
key = keys[idx]
temp = {}
temp['person_img_path'] = person_image_path_mapping[key]
temp['cloth_img_path'] = cloth_image_path_mapping[key]
temp['mask_img_path'] = mask_path_mapping[key]
temp[extra_condition_key] = extra_condition_image_path_mapping[key]
f.write(json.dumps(temp, ensure_ascii=False) + '\n')
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--person_image_root", type=str)
parser.add_argument("--cloth_image_root", type=str)
parser.add_argument("--mask_root", type=str)
parser.add_argument("--extra_condition_image_root", type=str)
parser.add_argument("--extra_condition_key", type=str)
parser.add_argument("--eval_nums", type=int, default=8)
parser.add_argument("--save_root", type=str)
args = parser.parse_args()
prepare_json(args.person_image_root,
args.cloth_image_root,
args.mask_root,
args.extra_condition_image_root,
args.extra_condition_key,
args.eval_nums,
args.save_root)
# Pre-resize the images ahead of time to speed up training
import os
from PIL import Image
from pathlib import Path
from tqdm import tqdm
def resize_image(data_root: str,
height: int = 512,
width: int = 384):
data_root = Path(data_root)
data_path_list = [*data_root.glob("*.png"), *data_root.glob("*.jpg"), *data_root.glob("*.jpeg"), *data_root.glob("*.JPEG")]
new_data_root = str(data_root) + f"_{height}"
os.makedirs(new_data_root, exist_ok=True)
for data_path in tqdm(data_path_list):
image_name = data_path.name
save_path = os.path.join(new_data_root, image_name)
image = Image.open(str(data_path))
image = image.resize((width, height), Image.LANCZOS)
image.save(save_path)
def main(args):
resize_image(args.data_root, args.height, args.width)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--data_root", type=str)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=384)
args = parser.parse_args()
main(args)
# find . -type f -name "*_rendered.png" -exec bash -c 'mv "$0" "${0/_rendered.png/.png}"' {} \;
from diffusers import AutoencoderKL
import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor
device = torch.device("cuda:0")
vae = AutoencoderKL.from_pretrained("/home/catvton_train/pretrained_models/stable-diffusion-inpainting/", subfolder="sd-vae-ft-mse")
vae.to(device).to(torch.bfloat16)
vae_processor = VaeImageProcessor(vae_scale_factor=8)
img_path = "./cloth/08424_00.jpg"
image = Image.open(img_path)
image = vae_processor.preprocess(image, 512, 384)[0]
image.unsqueeze_(0)
with torch.no_grad():
image_latent = vae.encode(image.to(device).to(vae.dtype)).latent_dist.sample()
image_latent = image_latent * vae.config.scaling_factor
image_latent = image_latent * (1/vae.config.scaling_factor)
image = vae.decode(image_latent).sample
image = (image / 2 + 0.5).clamp(0,1)
image = image.permute(0, 2, 3, 1).cpu().float().numpy()
image = image[0]
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image)
image.save("test.png")