vitonhd.py 3.04 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from torch.utils.data import Dataset
from diffusers.image_processor import VaeImageProcessor
from PIL import Image, ImageEnhance
from torchvision.transforms import transforms as T
from typing import Optional

import torch
import json
import random
import numpy as np

from pathlib import Path

current_dir = Path(__file__).resolve().parent

import sys

sys.path.insert(0, str(current_dir))

from aug import DataAugment


class VITHONHD(Dataset):
    """VITON-HD style virtual try-on dataset backed by a JSON-lines record file.

    Each non-blank line of ``data_record_path`` is one JSON object that must
    contain ``person_img_path``, ``cloth_img_path`` and ``mask_img_path`` and,
    when ``extra_condition_key`` is not ``"empty"``, an image path stored under
    that key.

    Args:
        data_record_path: Path to the JSON-lines record file. Blank lines are
            ignored.
        height: Target height passed to the image / mask preprocessors.
        width: Target width passed to the image / mask preprocessors.
        is_train: When True, samples go through ``DataAugment`` and only the
            model inputs are returned. When False, no augmentation is applied
            and ``person_ori``, ``mask_ori`` and ``name`` are returned as well.
        extra_condition_key: Record key of an additional condition image, or
            ``"empty"`` (``None`` is treated the same way) to emit an all-zero
            placeholder under the key ``"empty"`` instead.
        data_nums: Keep only ``records[:data_nums]``; ``None`` keeps all.
        **kwargs: Forwarded unchanged to ``DataAugment``.
    """

    def __init__(self,
                 data_record_path: str,
                 height: int,
                 width: int,
                 is_train: bool = True,
                 extra_condition_key: Optional[str] = "empty",
                 data_nums: Optional[int] = None,
                 **kwargs):

        self.data = self._load_records(data_record_path, data_nums)

        self.height = height
        self.width = width

        self.is_train = is_train
        # The type hint allows None; map it to the "no extra condition" sentinel
        # instead of failing later with ``data[None]``.
        self.extra_condition_key = "empty" if extra_condition_key is None else extra_condition_key

        self.totensor = T.ToTensor()

        # Default VaeImageProcessor settings for the RGB inputs. The mask processor
        # disables normalization and enables binarization + grayscale conversion.
        # vae_scale_factor=8 presumably matches the VAE's downsampling factor -- confirm.
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        self.mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)

        self.aug = DataAugment(**kwargs)

    @staticmethod
    def _load_records(data_record_path: str, data_nums: Optional[int] = None) -> list:
        """Parse a JSON-lines file into a list of dicts, keeping ``records[:data_nums]``."""
        with open(data_record_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. an extra empty line at EOF) rather than
            # crashing in json.loads("").
            lines = [line for line in f if line.strip()]
        # Slice before decoding so a small ``data_nums`` does not parse the whole file.
        return [json.loads(line) for line in lines[:data_nums]]

    @staticmethod
    def _open_image(path):
        """Open an image and load its pixels so the underlying file is closed right away."""
        # Pillow's context manager closes the file handle but keeps the loaded
        # image data usable after the ``with`` block.
        with Image.open(path) as img:
            img.load()
        return img

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        person, cloth, mask = [self._open_image(data[key]) for key in ('person_img_path', 'cloth_img_path', 'mask_img_path')]

        tmp_data = {"person": person, "cloth": cloth, "mask": mask}
        if self.extra_condition_key != "empty":
            tmp_data[self.extra_condition_key] = self._open_image(data[self.extra_condition_key])

        if self.is_train:
            # The augmented result must still provide the "person", "cloth" and
            # "mask" entries (and the extra condition, if any) read below.
            tmp_data = self.aug(tmp_data, self.extra_condition_key)

        return_data = {
            "person": self.vae_processor.preprocess(tmp_data['person'], self.height, self.width)[0],
            "cloth": self.vae_processor.preprocess(tmp_data['cloth'], self.height, self.width)[0],
            "mask": self.mask_processor.preprocess(tmp_data['mask'], self.height, self.width)[0],
        }

        # TODO: openpose; any further condition-specific processing is done outside this class.
        if self.extra_condition_key != "empty":
            # Converted to single-channel grayscale but NOT resized to
            # (height, width) here -- callers must handle size mismatches.
            return_data[self.extra_condition_key] = self.totensor(tmp_data[self.extra_condition_key].convert('L'))
        else:
            # Placeholder so every sample exposes a condition entry (stored under
            # the key "empty") shaped like the processed mask.
            return_data[self.extra_condition_key] = torch.zeros_like(return_data['mask'])

        if self.is_train:
            return return_data

        # Evaluation-only extras, built from the un-augmented source images.
        return_data.update({
            "person_ori": np.array(person.resize((self.width, self.height))),
            "mask_ori": np.array(mask.resize((self.width, self.height))),
            "name": Path(data['person_img_path']).name,  # file name
        })

        return return_data