cifar10.py 6.52 KB
Newer Older
1
2
import json
from pathlib import Path
3
from typing import Dict
4

5
import torch
6
from datasets import load_dataset
7
8
from einops import rearrange
from ldm.util import instantiate_from_config
9
10
11
12
13
from omegaconf import DictConfig, ListConfig
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

14
15
16
17
18
19
20
21
22

def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Build a single ConcatDataset out of several image folders.

    ``paths`` is either an iterable of folders, or a mapping (``Dict`` /
    ``DictConfig``) of ``folder -> number of repeats``; in the mapping form
    every folder is included that many times.

    ``caption_files``, when given, is paired with ``paths`` position by
    position (one caption file per folder). It cannot be combined with the
    mapping form of ``paths``.

    Any extra keyword arguments are forwarded to every ``FolderData``.
    """
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, "Caption files not yet supported for repeats"
        # Expand {folder: n} into a flat list with each folder listed n times.
        paths = [folder_path for folder_path, repeats in paths.items() for _ in range(repeats)]

    if caption_files is None:
        datasets = [FolderData(folder, **kwargs) for folder in paths]
    else:
        datasets = [
            FolderData(folder, caption_file=captions, **kwargs)
            for folder, captions in zip(paths, caption_files)
        ]
    return torch.utils.data.ConcatDataset(datasets)

34

35
class FolderData(Dataset):
    """Dataset of images under a folder, optionally paired with captions.

    Every item is a dict holding ``"image"`` (the output of the transform
    pipeline) and ``"txt"`` (the caption), plus ``"path"`` when
    ``return_paths`` is set. If a ``postprocess`` callable is configured, the
    dict is passed through it before being returned.
    """

    def __init__(
        self,
        root_dir,
        caption_file=None,
        image_transforms=None,
        ext="jpg",
        default_caption="",
        postprocess=None,
        return_paths=False,
    ) -> None:
        """Create a dataset from a folder of images.

        Args:
            root_dir: Folder that is searched recursively for images.
            caption_file: Optional ``.json`` or ``.jsonl`` file mapping file
                names (relative to ``root_dir``) to captions. When given, the
                dataset is indexed by the entries of this file rather than by
                the files found on disk.
            image_transforms: Optional list of transforms, or a ``ListConfig``
                of configs that are instantiated with
                ``instantiate_from_config``. The caller's list is not
                modified.
            ext: Image extension to search for, or a list/tuple/``ListConfig``
                of extensions.
            default_caption: Caption used when there is no caption file, or
                when an entry's caption is ``None``.
            postprocess: Optional callable (or ``DictConfig`` describing one)
                applied to every item dict.
            return_paths: If true, items also carry the image path under
                ``"path"``.

        Raises:
            ValueError: If ``caption_file`` is neither ``.json`` nor ``.jsonl``.
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess

        if caption_file is not None:
            self.captions = self._load_captions(caption_file)
            # Cache the key order once so __getitem__ does not rebuild the
            # full key list for every sample.
            self._caption_keys = list(self.captions)
        else:
            self.captions = None
            self._caption_keys = None

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        # Only used if there is no caption file
        self.paths = []
        for e in ext:
            self.paths.extend(self.root_dir.rglob(f"*.{e}"))

        # Always work on a fresh list: appending to the argument itself would
        # leak the extra steps into the caller's list (or into a shared
        # default) and stack them up across instances.
        if image_transforms is None:
            image_transforms = []
        elif isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        else:
            image_transforms = list(image_transforms)
        # Final steps: convert to a tensor, map it through x * 2 - 1
        # (i.e. [0, 1] -> [-1, 1]) and reorder "c h w" to "h w c".
        image_transforms.extend(
            [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
        )
        self.tform = transforms.Compose(image_transforms)

    @staticmethod
    def _load_captions(caption_file):
        """Read a ``{file_name: caption}`` mapping from a .json or .jsonl file.

        A ``.json`` file is returned as parsed. A ``.jsonl`` file must hold
        one object per line with ``"file_name"`` and ``"text"`` keys; blank
        lines are skipped and newlines are stripped from both ends of the
        text.
        """
        suffix = Path(caption_file).suffix.lower()
        with open(caption_file, "rt", encoding="utf-8") as f:
            if suffix == ".json":
                return json.load(f)
            if suffix == ".jsonl":
                records = (json.loads(line) for line in f if line.strip())
                return {rec["file_name"]: rec["text"].strip("\n") for rec in records}
            raise ValueError(f"Unrecognised format: {suffix}")

    def __len__(self):
        """Number of caption entries if captions were loaded, else number of image files."""
        if self.captions is not None:
            return len(self.captions)
        return len(self.paths)

    def __getitem__(self, index):
        """Load, transform and return the item at ``index`` as a dict."""
        if self.captions is not None:
            chosen = self._caption_keys[index]
            caption = self.captions.get(chosen)
            if caption is None:
                caption = self.default_caption
            filename = self.root_dir / chosen
        else:
            caption = self.default_caption
            filename = self.paths[index]

        data = {}
        if self.return_paths:
            data["path"] = str(filename)

        # The context manager releases the underlying file as soon as the
        # image has been converted and transformed.
        with Image.open(filename) as im:
            data["image"] = self.process_im(im)
        data["txt"] = caption

        if self.postprocess is not None:
            data = self.postprocess(data)

        return data

    def process_im(self, im):
        """Convert ``im`` to RGB and run it through the transform pipeline."""
        im = im.convert("RGB")
        return self.tform(im)

124

125
126
127
128
129
130
def hf_dataset(
    name,
    image_transforms=None,
    image_column="img",
    label_column="label",
    text_column="txt",
    split="train",
    image_key="image",
    caption_key="txt",
):
    """Make huggingface dataset with appropriate list of transforms applied.

    Args:
        name: Dataset name passed to ``load_dataset``.
        image_transforms: Optional iterable of transform configs; each one is
            instantiated with ``instantiate_from_config``.
        image_column: Column of the loaded dataset that holds the images.
        label_column: Column that holds the integer class labels.
        text_column: Accepted for interface compatibility; not used by this
            function.
        split: Dataset split passed to ``load_dataset``.
        image_key: Key under which transformed images are returned.
        caption_key: Key under which the class-name captions are returned.

    Returns:
        The loaded dataset with ``pre_process`` registered through
        ``set_transform``. Each processed batch is a dict with ``image_key``
        and ``caption_key`` only.
    """
    ds = load_dataset(name, split=split)

    if image_transforms is None:
        image_transforms = []
    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    # Final steps: convert to a tensor, map it through x * 2 - 1 and reorder
    # "c h w" to "h w c".
    image_transforms.extend(
        [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))]
    )
    tform = transforms.Compose(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}"

    # Built once here rather than on every pre_process call. Only labels 0-9
    # are known; any other label raises KeyError in pre_process.
    label_to_text_dict = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }

    def pre_process(examples):
        """Transform a batch of images and turn integer labels into class names."""
        return {
            image_key: [tform(im) for im in examples[image_column]],
            caption_key: [label_to_text_dict[label] for label in examples[label_column]],
        }

    ds.set_transform(pre_process)
    return ds

170

171
172
173
174
175
176
177
178
179
180
181
182
183
class TextOnly(Dataset):
    """Dataset that yields captions paired with constant placeholder images."""

    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images.

        Args:
            captions: Either a sequence of captions, or a ``Path`` to a text
                file with one caption per line.
            output_size: Height and width of the square dummy image.
            image_key: Key under which the dummy image is returned.
            caption_key: Key under which the caption is returned.
            n_gpus: When greater than 1, every caption is repeated ``n_gpus``
                times back to back.
        """
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            captions = self._load_caption_file(captions)

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            captions = [caption for caption in captions for _ in range(n_gpus)]
        self.captions = captions

    def __len__(self):
        """Number of captions, counting the per-GPU repeats."""
        return len(self.captions)

    def __getitem__(self, index):
        """Return a dict with the dummy image and the caption at ``index``."""
        size = self.output_size
        # Same values and "h w c" shape as zeros(3, size, size) * 2 - 1
        # rearranged from "c h w", built directly.
        dummy_im = torch.full((size, size, 3), -1.0)
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        """Read one caption per line from ``filename``, dropping the newlines."""
        with open(filename, "rt", encoding="utf-8") as f:
            return [line.strip("\n") for line in f]