import json
from os import path as osp
import argparse
from typing import Optional

import numpy as np
from PIL import Image, ImageDraw
import torch
from torch.utils import data
from torchvision import transforms
import lightning as L


class VITONDataset(data.Dataset):
    """Virtual try-on dataset yielding person, garment, pose and parsing data.

    Each sample is read from sub-folders of ``data_dir``: ``image``,
    ``cloth``, ``cloth-mask``, ``openpose_img``, ``openpose_json`` and
    ``image-parse-v3``.

    NOTE(review): pose keypoints are used exactly as stored in the JSON file
    and are never rescaled, while every image is resized so that its shorter
    side equals ``width``. This presumably assumes the files on disk already
    have the ``width`` x ``height`` resolution -- TODO confirm.

    Args:
        data_dir: Root folder holding the sub-folders listed above.
        dataset_list: Text file with one ``<image name> <cloth name>`` pair
            per line.
        height: Height used for the blank masks and the one-hot parsing maps.
        width: Target size handed to ``transforms.Resize`` and width of the
            blank masks / one-hot parsing maps.
        semantic_nc: Number of channels of the returned parsing map; must be
            at least ``len(_LABEL_GROUPS)``.
    """

    # Output channel -> [name, parsing label ids merged into that channel].
    _LABEL_GROUPS = {
        0: ['background', [0, 10]],
        1: ['hair', [1, 2]],
        2: ['face', [4, 13]],
        3: ['upper', [5, 6, 7]],
        4: ['bottom', [9, 12]],
        5: ['left_arm', [14]],
        6: ['right_arm', [15]],
        7: ['left_leg', [16]],
        8: ['right_leg', [17]],
        9: ['left_shoe', [18]],
        10: ['right_shoe', [19]],
        11: ['socks', [8]],
        12: ['noise', [3, 11]],
    }

    def __init__(self, data_dir: str, dataset_list: str, height: int, width: int, semantic_nc: int):
        super().__init__()
        self.load_height = height
        self.load_width = width
        self.semantic_nc = semantic_nc
        self.data_path = data_dir

        # Normalisation to [-1, 1] is not applied, so tensors produced by
        # ``self.transform`` stay in the [0, 1] range given by ToTensor.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Explicit interpolation modes (the legacy integer codes 2 and 0 map
        # to BILINEAR and NEAREST); built once instead of once per sample.
        self._resize_bilinear = transforms.Resize(
            self.load_width, interpolation=transforms.InterpolationMode.BILINEAR)
        self._resize_nearest = transforms.Resize(
            self.load_width, interpolation=transforms.InterpolationMode.NEAREST)

        # Load the pair list: first column is the person image, second the
        # garment image.
        img_names = []
        with open(dataset_list, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                img_name, _cloth_name = fields
                img_names.append(img_name)

        self.img_names = img_names
        # The person image and its garment share the same file name (they
        # live in different folders), so the paired list reuses img_names
        # and the second column of the pair file is not used.
        self.c_names = {'paired': img_names}

    @staticmethod
    def _is_missing(pose_data, idx) -> bool:
        """Return True when keypoint ``idx`` sits at (0, 0), i.e. undetected."""
        return pose_data[idx, 0] == 0.0 and pose_data[idx, 1] == 0.0

    @staticmethod
    def _to_mask(array) -> Image.Image:
        """Convert a {0, 1}-valued float array into an 8-bit 'L' paste mask."""
        return Image.fromarray(np.uint8(array * 255), 'L')

    def get_parse_agnostic(self, parse, pose_data):
        """Return a copy of the parsing map with arms, upper body and neck zeroed.

        Args:
            parse: PIL image holding the semantic segmentation labels.
            pose_data: Array of (x, y) keypoints indexed by joint id
                (presumably OpenPose body keypoints -- TODO confirm).

        Returns:
            A PIL image like ``parse`` where the selected regions are set to 0.
        """
        parse_array = np.array(parse)
        # Binary maps of the 'upper' labels (5, 6, 7) and of label 10.
        parse_upper = ((parse_array == 5).astype(np.float32) +
                       (parse_array == 6).astype(np.float32) +
                       (parse_array == 7).astype(np.float32))
        parse_neck = (parse_array == 10).astype(np.float32)

        r = 10
        agnostic = parse.copy()

        # Mask arms: label 14 is the left arm, label 15 the right arm; each
        # comes with the chain of keypoints that traces it.
        for parse_id, pose_ids in [(14, [2, 5, 6, 7]), (15, [5, 2, 3, 4])]:
            mask_arm = Image.new('L', (self.load_width, self.load_height), 'black')
            mask_arm_draw = ImageDraw.Draw(mask_arm)
            i_prev = pose_ids[0]
            for i in pose_ids[1:]:
                if self._is_missing(pose_data, i_prev) or self._is_missing(pose_data, i):
                    # Skip the segment; i_prev is deliberately left unchanged.
                    continue
                mask_arm_draw.line([tuple(pose_data[j]) for j in [i_prev, i]], 'white', width=r*10)
                pointx, pointy = pose_data[i]
                # Smaller blob at the last joint of the chain.
                radius = r*4 if i == pose_ids[-1] else r*15
                mask_arm_draw.ellipse((pointx-radius, pointy-radius, pointx+radius, pointy+radius), 'white', 'white')
                i_prev = i
            # Only erase pixels that are both inside the drawn region and
            # labelled as this arm.
            parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
            agnostic.paste(0, None, self._to_mask(parse_arm))

        # Mask torso & neck.
        agnostic.paste(0, None, self._to_mask(parse_upper))
        agnostic.paste(0, None, self._to_mask(parse_neck))

        return agnostic

    def get_img_agnostic(self, img, parse, pose_data):
        """Return a copy of ``img`` with arms, torso and neck painted gray.

        Pixels labelled as face (4, 13) and lower body / legs / shoes
        (9, 12, 16-19) are restored from the original image afterwards.

        Args:
            img: PIL person image.
            parse: PIL image holding the semantic segmentation labels.
            pose_data: Array of (x, y) keypoints indexed by joint id. It is
                not modified; the method works on a private copy.

        Returns:
            The clothing-agnostic PIL image.
        """
        # Work on a copy so the caller's keypoints are never altered.
        pose_data = pose_data.copy()

        parse_array = np.array(parse)
        parse_head = ((parse_array == 4).astype(np.float32) +
                      (parse_array == 13).astype(np.float32))
        parse_lower = ((parse_array == 9).astype(np.float32) +
                       (parse_array == 12).astype(np.float32) +
                       (parse_array == 16).astype(np.float32) +
                       (parse_array == 17).astype(np.float32) +
                       (parse_array == 18).astype(np.float32) +
                       (parse_array == 19).astype(np.float32))

        r = 20
        agnostic = img.copy()
        agnostic_draw = ImageDraw.Draw(agnostic)

        # Stretch keypoints 9 and 12 about their midpoint so their distance
        # equals the distance between keypoints 2 and 5.
        length_a = np.linalg.norm(pose_data[5] - pose_data[2])
        length_b = np.linalg.norm(pose_data[12] - pose_data[9])
        if length_b > 0:
            # Coincident points cannot be rescaled (division by zero would
            # yield NaN coordinates), so they are left where they are.
            point = (pose_data[9] + pose_data[12]) / 2
            pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
            pose_data[12] = point + (pose_data[12] - point) / length_b * length_a

        # Mask arms.
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r*10)
        for i in [2, 5]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
        for i in [3, 4, 6, 7]:
            if self._is_missing(pose_data, i - 1) or self._is_missing(pose_data, i):
                continue
            agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r*10)
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')

        # Mask torso.
        for i in [9, 12]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r*12)
        agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')

        # Mask neck.
        pointx, pointy = pose_data[1]
        agnostic_draw.rectangle((pointx-r*7, pointy-r*7, pointx+r*7, pointy+r*7), 'gray', 'gray')

        # Restore head and lower-body pixels from the original image.
        agnostic.paste(img, None, self._to_mask(parse_head))
        agnostic.paste(img, None, self._to_mask(parse_lower))

        return agnostic

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with keys ``img_name``, ``c_name``, ``img``,
            ``img_agnostic``, ``parse_agnostic``, ``pose``, ``cloth`` and
            ``cloth_mask``. ``c_name``, ``cloth`` and ``cloth_mask`` are
            themselves dicts keyed like ``self.c_names``.
        """
        img_name = self.img_names[index]
        c_name = {}
        c = {}   # garment images
        cm = {}  # garment masks
        for key in self.c_names:
            c_name[key] = self.c_names[key][index]
            # Read the garment image and its mask, then resize both.
            c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
            c[key] = self._resize_bilinear(c[key])
            cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
            cm[key] = self._resize_nearest(cm[key])

            c[key] = self.transform(c[key])  # [0, 1]
            # Binarise the mask at the mid-gray threshold.
            cm_array = np.array(cm[key])
            cm_array = (cm_array >= 128).astype(np.float32)
            cm[key] = torch.from_numpy(cm_array)  # values in {0, 1}
            cm[key].unsqueeze_(0)

        # Load the rendered pose image.
        pose_name = img_name.replace('.jpg', '_rendered.png')
        pose_rgb = Image.open(osp.join(self.data_path, 'openpose_img', pose_name))
        pose_rgb = self._resize_bilinear(pose_rgb)
        pose_rgb = self.transform(pose_rgb)  # [0, 1]

        # Load the keypoints of the first detected person; the flat list is
        # reshaped into triples of which only the first two columns (x, y)
        # are kept.
        pose_name = img_name.replace('.jpg', '_keypoints.json')
        with open(osp.join(self.data_path, 'openpose_json', pose_name), 'r') as f:
            pose_label = json.load(f)
            pose_data = pose_label['people'][0]['pose_keypoints_2d']
            pose_data = np.array(pose_data)
            pose_data = pose_data.reshape((-1, 3))[:, :2]

        # Load the parsing image (semantic segmentation map).
        parse_name = img_name.replace('.jpg', '.png')
        parse = Image.open(osp.join(self.data_path, 'image-parse-v3', parse_name))
        parse = self._resize_nearest(parse)
        parse_agnostic = self.get_parse_agnostic(parse, pose_data)
        parse_agnostic = torch.from_numpy(np.array(parse_agnostic)[None]).long()

        # One-hot encode the 20 raw labels (one channel per label), then
        # merge them into the coarser groups of _LABEL_GROUPS.
        parse_agnostic_map = torch.zeros(20, self.load_height, self.load_width, dtype=torch.float)
        parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
        new_parse_agnostic_map = torch.zeros(self.semantic_nc, self.load_height, self.load_width, dtype=torch.float)
        for channel, (_group_name, label_ids) in self._LABEL_GROUPS.items():
            for label in label_ids:
                new_parse_agnostic_map[channel] += parse_agnostic_map[label]

        # Load the person image.
        img = Image.open(osp.join(self.data_path, 'image', img_name))
        img = self._resize_bilinear(img)
        img_agnostic = self.get_img_agnostic(img, parse, pose_data)
        img = self.transform(img)  # [0, 1]
        img_agnostic = self.transform(img_agnostic)  # [0, 1]

        result = {
            'img_name': img_name,
            'c_name': c_name,
            'img': img,
            'img_agnostic': img_agnostic,
            'parse_agnostic': new_parse_agnostic_map,
            'pose': pose_rgb,
            'cloth': c,
            'cloth_mask': cm,
        }
        return result

    def __len__(self):
        """Return the number of samples listed in the pair file."""
        return len(self.img_names)


class VITONDataModule(L.LightningDataModule):
    """Lightning data module wrapping :class:`VITONDataset`.

    Args:
        root_dir: Dataset root; samples are read from ``<root_dir>/<mode>``
            and the pair list from ``<root_dir>/<mode>_pairs.txt``.
        mode: Name of the split sub-folder / pair-file prefix.
        height: Forwarded to ``VITONDataset``.
        width: Forwarded to ``VITONDataset``.
        semantic_nc: Forwarded to ``VITONDataset``.
        batch_size: Batch size of the returned data loader.
        num_workers: Worker process count of the returned data loader.
    """

    def __init__(self,
                 root_dir,
                 mode: str = "train",
                 height: int = 512,
                 width: int = 384,
                 semantic_nc: int = 13,
                 batch_size: int = 1,
                 num_workers: int = 1):
        super().__init__()
        self.data_dir = osp.join(root_dir, mode)
        self.dataset_list = osp.join(root_dir, f"{mode}_pairs.txt")
        self.height, self.width = height, width
        self.semantic_nc = semantic_nc
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        """Build the dataset; ``stage`` is accepted but not used."""
        self.train_dataset = VITONDataset(
            self.data_dir,
            self.dataset_list,
            self.height,
            self.width,
            self.semantic_nc
        )

    def train_dataloader(self):
        """Return a ``DataLoader`` over the dataset created by ``setup``."""
        return data.DataLoader(self.train_dataset,
                               batch_size=self.batch_size,
                               num_workers=self.num_workers)


if __name__ == "__main__":
    # Smoke test: print the first batch and stop.
    dm = VITONDataModule("/parastor/home/mashun/modelzoo/OOTDiffusion/datasets/VITON-HD")
    dm.setup()
    dl = dm.train_dataloader()
    # Named ``batch`` so the imported ``data`` module is not shadowed.
    for batch in dl:
        print(batch)
        break