from torch.utils.data import Dataset
from diffusers.image_processor import VaeImageProcessor
from PIL import Image, ImageEnhance
from torchvision.transforms import transforms as T
from typing import Optional
import torch
import json
import random
import numpy as np
from pathlib import Path

current_dir = Path(__file__).resolve().parent
import sys

# Make the sibling ``aug`` module importable regardless of the working directory.
sys.path.insert(0, str(current_dir))
from aug import DataAugment  # noqa: E402  (must come after the sys.path tweak)


class VITHONHD(Dataset):
    """Person / garment / mask image dataset described by a JSON-lines record file.

    Each non-blank line of ``data_record_path`` is a JSON object that must hold
    the image paths ``person_img_path``, ``cloth_img_path`` and
    ``mask_img_path``. When ``extra_condition_key`` is enabled, the record must
    also hold an image path under that key. (The class name suggests VITON-HD
    virtual try-on data; confirm against the data-preparation script.)

    Args:
        data_record_path: Path to the JSON-lines record file (read as UTF-8).
        height: Target height passed to the VAE / mask preprocessors.
        width: Target width passed to the VAE / mask preprocessors.
        is_train: When true, samples go through ``DataAugment`` and only the
            model inputs are returned. When false, the un-augmented resized
            person/mask arrays and the person file name are also returned.
        extra_condition_key: Record key of an optional extra conditioning
            image (the original TODO mentions openpose). ``"empty"`` or
            ``None`` disables it. In that case an all-zero tensor shaped like
            the processed mask is returned under the key ``"empty"``.
        data_nums: Applied as ``lines[:data_nums]`` to the raw file lines;
            ``None`` keeps all of them.
        **kwargs: Forwarded unchanged to ``DataAugment``.
    """

    def __init__(self, data_record_path: str, height: int, width: int, is_train: bool = True,
                 extra_condition_key: Optional[str] = "empty", data_nums: Optional[int] = None,
                 **kwargs):
        # The signature advertises ``None``; map it to the "disabled" sentinel
        # instead of failing later on ``record[None]``.
        if extra_condition_key is None:
            extra_condition_key = "empty"

        self.data = []
        with open(data_record_path, "r", encoding="utf-8") as f:
            for line in f.readlines()[:data_nums]:
                line = line.strip()
                if not line:
                    # Blank lines carry no record; json.loads("") would raise.
                    continue
                self.data.append(json.loads(line))

        self.height = height
        self.width = width
        self.is_train = is_train
        self.extra_condition_key = extra_condition_key
        self.totensor = T.ToTensor()
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks are binarized single-channel images and are not normalized.
        self.mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False,
                                                do_binarize=True, do_convert_grayscale=True)
        self.aug = DataAugment(**kwargs)

    def __len__(self):
        """Return the number of records loaded from the record file."""
        return len(self.data)

    @staticmethod
    def _open_image(path):
        """Open ``path`` and read its pixels now, so the file handle is closed immediately."""
        # Image.open is lazy and keeps the file open until the data is loaded.
        # Loading inside ``with`` releases the descriptor deterministically,
        # which matters for long-running DataLoader workers.
        with Image.open(path) as img:
            img.load()
        return img

    def __getitem__(self, idx):
        """Build one sample.

        Returns a dict with these keys:

        - ``"person"`` and ``"cloth"``: outputs of ``VaeImageProcessor.preprocess``
          at (height, width), first batch element.
        - ``"mask"``: the same via the binarizing grayscale mask processor.
        - ``self.extra_condition_key``: the extra condition as a 1-channel
          ``ToTensor`` output, or zeros shaped like ``"mask"`` when disabled.

        When ``is_train`` is false, these are also included:

        - ``"person_ori"`` and ``"mask_ori"``: numpy arrays of the original
          images resized to (width, height).
        - ``"name"``: the person image's file name.
        """
        record = self.data[idx]
        person, cloth, mask = (
            self._open_image(record[key])
            for key in ("person_img_path", "cloth_img_path", "mask_img_path")
        )
        use_extra = self.extra_condition_key != "empty"

        tmp_data = {"person": person, "cloth": cloth, "mask": mask}
        if use_extra:
            tmp_data[self.extra_condition_key] = self._open_image(record[self.extra_condition_key])

        if self.is_train:
            tmp_data = self.aug(tmp_data, self.extra_condition_key)

        # ``[0]`` drops the batch dimension added by ``preprocess``.
        return_data = {
            "person": self.vae_processor.preprocess(tmp_data["person"], self.height, self.width)[0],
            "cloth": self.vae_processor.preprocess(tmp_data["cloth"], self.height, self.width)[0],
            "mask": self.mask_processor.preprocess(tmp_data["mask"], self.height, self.width)[0],
        }

        # TODO: openpose; any further processing of the extra condition is left to the caller.
        # NOTE(review): unlike person/cloth/mask, the extra condition is not resized to
        # (height, width) here. Outside training it keeps its source resolution, while
        # the "empty" placeholder matches the mask size. Confirm that callers expect this.
        if use_extra:
            return_data[self.extra_condition_key] = self.totensor(
                tmp_data[self.extra_condition_key].convert("L"))
        else:
            return_data[self.extra_condition_key] = torch.zeros_like(return_data["mask"])

        if self.is_train:
            return return_data

        # PIL sizes are (width, height).
        target_size = (self.width, self.height)
        return_data.update({
            "person_ori": np.array(person.resize(target_size)),
            "mask_ori": np.array(mask.resize(target_size)),
            "name": Path(record["person_img_path"]).name,  # file name
        })
        return return_data