# lsun.py: LSUN image dataset wrappers (churches, bedrooms, cats).
import os
2

3
4
5
6
7
8
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


# Base dataset of LSUN images listed in a text file; the split-specific subclasses below supply the paths.
class LSUNBase(Dataset):
    """Dataset of LSUN images whose relative paths are listed in a text file.

    Each item is center-cropped to a square, optionally resized to
    ``size x size``, randomly flipped horizontally with probability
    ``flip_p``, and scaled from uint8 ``[0, 255]`` to float32 ``[-1, 1]``.
    """

    def __init__(
        self,
        txt_file,  # path to the text file listing image paths, one per line
        data_root,  # directory the listed paths are relative to
        size=None,  # side length to resize the square crop to; None keeps the crop size
        interpolation="bicubic",  # one of "linear", "bilinear", "bicubic", "lanczos"
        flip_p=0.5,  # probability of a random horizontal flip
    ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            # Skip blank lines: they would join to data_root itself and fail in Image.open.
            self.image_paths = [line for line in f.read().splitlines() if line.strip()]
        self._length = len(self.image_paths)

        # Per-index path metadata that is returned alongside each image.
        self.labels = {
            "relative_file_path_": list(self.image_paths),
            "file_path_": [os.path.join(self.data_root, path) for path in self.image_paths],
        }

        self.size = size
        # PIL.Image.LINEAR was only an alias of BILINEAR and was removed in Pillow 10;
        # referencing it here made construction fail for every interpolation choice.
        # An unknown name still raises KeyError, as before.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the number of listed images."""
        return self._length

    def __getitem__(self, i):
        """Load and preprocess the ``i``-th image.

        Returns a dict with the entries ``relative_file_path_`` and
        ``file_path_`` plus ``image``: a float32 array of shape (H, W, 3)
        with values in [-1, 1].
        """
        example = {key: values[i] for key, values in self.labels.items()}

        # The context manager closes the underlying file as soon as the pixels
        # have been copied into a numpy array.
        with Image.open(example["file_path_"]) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            img = np.array(rgb).astype(np.uint8)

        # Center-crop to the largest square (score-sde preprocessing).
        h, w = img.shape[:2]
        crop = min(h, w)
        img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example


# Dataset for the LSUN Churches (church_outdoor) training split.
# It calls the LSUNBase constructor with the split's text file of image paths and the root
# directory where the images are stored; any additional keyword arguments are forwarded to LSUNBase.
class LSUNChurchesTrain(LSUNBase):
    """LSUN church_outdoor training split with images under data/lsun/churches."""

    def __init__(self, **kwargs):
        split_paths = {
            "txt_file": "data/lsun/church_outdoor_train.txt",
            "data_root": "data/lsun/churches",
        }
        # Remaining options (size, interpolation, flip_p) pass straight through.
        super().__init__(**split_paths, **kwargs)


# Dataset for the LSUN Churches validation split.
# Like LSUNChurchesTrain, but it reads a different text file and sets the flip probability to zero by default.
class LSUNChurchesValidation(LSUNBase):
    """LSUN church_outdoor validation split; flipping is disabled unless flip_p is given."""

    def __init__(self, flip_p=0.0, **kwargs):
        split_options = {
            "txt_file": "data/lsun/church_outdoor_val.txt",
            "data_root": "data/lsun/churches",
            "flip_p": flip_p,
        }
        super().__init__(**split_options, **kwargs)


# Dataset for the LSUN Bedrooms training split.
# It calls the LSUNBase constructor with the split's text file and image root directory.
class LSUNBedroomsTrain(LSUNBase):
    """LSUN bedrooms training split with images under data/lsun/bedrooms."""

    def __init__(self, **kwargs):
        split_paths = {
            "txt_file": "data/lsun/bedrooms_train.txt",
            "data_root": "data/lsun/bedrooms",
        }
        # Remaining options (size, interpolation, flip_p) pass straight through.
        super().__init__(**split_paths, **kwargs)


# Dataset for the LSUN Bedrooms validation split.
# Like LSUNBedroomsTrain, but it reads a different text file and sets the flip probability to zero by default.
class LSUNBedroomsValidation(LSUNBase):
    """LSUN bedrooms validation split; flipping is disabled unless flip_p is given."""

    def __init__(self, flip_p=0.0, **kwargs):
        split_options = {
            "txt_file": "data/lsun/bedrooms_val.txt",
            "data_root": "data/lsun/bedrooms",
            "flip_p": flip_p,
        }
        super().__init__(**split_options, **kwargs)


# Dataset for the LSUN Cats training split.
# It calls the LSUNBase constructor with the split's text file of image paths and the root
# directory where the images are stored; any additional keyword arguments are forwarded to LSUNBase.
class LSUNCatsTrain(LSUNBase):
    """LSUN cats training split with images under data/lsun/cats."""

    def __init__(self, **kwargs):
        split_paths = {
            "txt_file": "data/lsun/cat_train.txt",
            "data_root": "data/lsun/cats",
        }
        # Remaining options (size, interpolation, flip_p) pass straight through.
        super().__init__(**split_paths, **kwargs)


# Dataset for the LSUN Cats validation split.
# Like LSUNCatsTrain, but it reads a different text file and sets the flip probability to zero by default.
class LSUNCatsValidation(LSUNBase):
    """LSUN cats validation split; flipping is disabled unless flip_p is given."""

    def __init__(self, flip_p=0.0, **kwargs):
        split_options = {
            "txt_file": "data/lsun/cat_val.txt",
            "data_root": "data/lsun/cats",
            "flip_p": flip_p,
        }
        super().__init__(**split_options, **kwargs)