base.py 2.67 KB
Newer Older
1
import os
2
3

import cv2
4
5
import numpy as np
import torch
6
from torch.utils.data import IterableDataset
7

8
9

class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable.

    Samples are discovered under ``file_path``: every immediate sub-directory
    whose name contains no ``.`` is scanned for ``*.jpg`` images, and each image
    ``<stem>.jpg`` is paired with the caption file ``<stem>.txt`` in the same
    directory.

    Args:
        file_path: Root directory containing the sample sub-directories.
        rank: Rank of this process. Stored on the instance; not used for sharding.
        world_size: Number of processes. Stored on the instance; not used for sharding.
    """

    def __init__(self, file_path: str, rank, world_size):
        super().__init__()
        self.file_path = file_path
        self.folder_list = []
        self.file_list = []
        self.txt_list = []
        # Fills folder_list / file_list / txt_list as a side effect.
        self.info = self._get_file_info(file_path)
        self.start = self.info["start"]
        self.end = self.info["end"]
        self.rank = rank
        self.world_size = world_size
        # NOTE(review): rank / world_size are stored but no per-rank or
        # per-DataLoader-worker sharding is applied, so every process (and every
        # worker) iterates the full [start, end) range. Confirm this is intended.
        self.num_records = self.end - self.start
        self.valid_ids = list(range(self.end))

        print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")

    def __len__(self):
        """Return the number of samples produced by one full iteration."""
        return self.end - self.start

    def __iter__(self):
        """Return a generator over every sample in ``[start, end)``."""
        return self._sample_generator(self.start, self.end)

    def _sample_generator(self, start, end):
        """Yield ``{"txt": caption, "image": tensor}`` for indices ``[start, end)``.

        The caption is the full text of the paired ``.txt`` file, read as UTF-8.
        The image is decoded with flag 1 (``cv2.IMREAD_COLOR``, 3-channel BGR),
        converted to RGB, and divided by 255, giving a float tensor in ``[0, 1]``.

        Raises:
            ValueError: If an image file cannot be decoded.
        """
        for idx in range(start, end):
            file_name = self.file_list[idx]
            txt_name = self.txt_list[idx]
            # Context manager guarantees the caption file is closed even on error.
            with open(txt_name, "r", encoding="utf-8") as f_:
                txt_ = f_.read()
            image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
            if image is None:
                # imdecode signals failure by returning None; fail with the path
                # instead of letting cvtColor raise an unrelated-looking error.
                raise ValueError(f"Failed to decode image: {file_name}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = torch.from_numpy(image) / 255
            yield {"txt": txt_, "image": image}

    def _get_file_info(self, file_path):
        """Scan ``file_path`` for image/caption pairs and return the index range.

        Replaces ``self.folder_list`` and extends ``self.file_list`` and
        ``self.txt_list``. Directory listings are sorted so sample order is
        deterministic across runs and processes.

        Returns:
            ``{"start": 0, "end": <number of discovered images>}``.
        """
        info = {
            # 0 so the first discovered image is not skipped.
            "start": 0,
            "end": 0,
        }
        self.folder_list = [
            os.path.join(file_path, name)
            for name in sorted(os.listdir(file_path))
            # Names containing "." are treated as files; the isdir check also
            # drops extension-less regular files, which os.listdir would reject.
            if "." not in name and os.path.isdir(os.path.join(file_path, name))
        ]
        for folder in self.folder_list:
            files = [
                os.path.join(folder, name)
                for name in sorted(os.listdir(folder))
                if name.endswith(".jpg")
            ]
            # Swap only the extension; a global "jpg" -> "txt" replace would also
            # rewrite any directory name containing "jpg".
            txts = [os.path.splitext(path)[0] + ".txt" for path in files]
            self.file_list.extend(files)
            self.txt_list.extend(txts)
        info["end"] = len(self.file_list)
        return info