# base.py
import math
import os
from abc import abstractmethod

import cv2
import numpy as np
import torch
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable.

    The dataset is built from a directory tree rooted at ``file_path``:
    every direct sub-entry whose name contains no '.' and is a directory is
    scanned for ``<stem>.jpg`` images, each paired with a caption file
    ``<stem>.txt`` in the same folder.

    Iterating yields dicts ``{"txt": caption_text, "image": image_tensor}``.

    NOTE(review): ``rank`` and ``world_size`` are stored but not used to
    shard the data -- every process iterates the full index range. A
    per-rank split (start + rank * floor((end - start) / world_size)) was
    sketched in earlier revisions but is disabled; confirm this is intended
    before running with more than one process or DataLoader worker.
    '''

    def __init__(self, file_path: str, rank, world_size):
        '''
        Args:
            file_path: root directory to scan; a trailing separator is optional.
            rank: index of this process (stored only, see class note).
            world_size: number of processes (stored only, see class note).
        '''
        super().__init__()
        self.file_path = file_path
        # These are populated by _get_file_info, so they must exist first.
        self.folder_list = []
        self.file_list = []
        self.txt_list = []
        self.info = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end = self.info['end']
        self.rank = rank
        self.world_size = world_size

        self.num_records = self.end - self.start
        self.valid_ids = list(range(self.end))

        print(f'{self.__class__.__name__} dataset contains {len(self)} examples.')

    def __len__(self):
        '''Number of samples produced by one full iteration.'''
        return self.end - self.start

    def __iter__(self):
        '''Return a generator over samples ``self.start`` .. ``self.end - 1``.'''
        return self._sample_generator(self.start, self.end)

    def _sample_generator(self, start, end):
        '''
        Yield one sample per index in ``range(start, end)``.

        Each sample is ``{"txt": str, "image": torch.Tensor}``. The image is
        decoded as a 3-channel colour image (cv2.IMREAD_COLOR), converted
        from BGR to RGB, and divided by 255, giving a floating-point tensor
        (presumably laid out (H, W, 3), as cv2 returns it).

        Raises:
            ValueError: if an image file cannot be decoded.
        '''
        for idx in range(start, end):
            file_name = self.file_list[idx]
            txt_name = self.txt_list[idx]
            # `with` guarantees the handle is closed even if read() fails;
            # captions are assumed to be UTF-8 text.
            with open(txt_name, 'r', encoding='utf-8') as f_:
                txt_ = f_.read()
            # Reading bytes via numpy and decoding in memory (instead of
            # cv2.imread) is presumably done to support non-ASCII paths.
            raw = np.fromfile(file_name, dtype=np.uint8)
            image = cv2.imdecode(raw, cv2.IMREAD_COLOR)
            if image is None:
                # imdecode signals failure with None; fail with a clear
                # message instead of a cryptic cvtColor error.
                raise ValueError(f'failed to decode image file: {file_name}')
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = torch.from_numpy(image) / 255
            yield {"txt": txt_, "image": image}

    def _get_file_info(self, file_path):
        '''
        Scan ``file_path`` and build the image/caption index.

        Side effects: sets ``self.folder_list`` and extends
        ``self.file_list`` / ``self.txt_list`` (which must already exist).
        Listings are sorted so the index order is reproducible across runs.

        Returns:
            dict with ``start`` (first index to iterate, 0) and ``end``
            (exclusive upper bound, i.e. the number of images found).
        '''
        # Entries containing '.' are skipped (as before); the isdir guard
        # keeps os.listdir from failing on extension-less plain files.
        self.folder_list = [
            os.path.join(file_path, name)
            for name in sorted(os.listdir(file_path))
            if '.' not in name and os.path.isdir(os.path.join(file_path, name))
        ]
        for folder in self.folder_list:
            images = [
                os.path.join(folder, name)
                for name in sorted(os.listdir(folder))
                if name.endswith('.jpg')
            ]
            # Swap only the extension: str.replace('jpg', 'txt') would also
            # rewrite any 'jpg' appearing in the folder path or file stem.
            captions = [os.path.splitext(path)[0] + '.txt' for path in images]
            self.file_list.extend(images)
            self.txt_list.extend(captions)
        # file_list is 0-indexed with no header entry; starting at 1 would
        # silently drop the first image.
        return {"start": 0, "end": len(self.file_list)}