lsun.py 5.98 KB
Newer Older
1
2
3
4
5
6
7
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

NatalieC323's avatar
NatalieC323 committed
8
# Base dataset class that loads the LSUN images listed in a text file; the training and validation subsets below all build on it.
9
10
class LSUNBase(Dataset):
    """Dataset of LSUN images listed in a text file.

    Every non-blank line of ``txt_file`` is an image path relative to
    ``data_root``. Indexing returns a dict holding the relative path, the
    joined path and the image itself, centre-cropped to a square, optionally
    resized, randomly flipped and rescaled to ``float32`` values in
    ``[-1, 1]``.
    """

    def __init__(self,
                 txt_file,                  # path to the text file containing the list of image paths
                 data_root,                 # root directory of the LSUN dataset
                 size=None,                 # the size of images to resize to
                 interpolation="bicubic",   # interpolation method to be used while resizing
                 flip_p=0.5                 # probability of random horizontal flipping
                 ):
        """Read the image list and set up the resize and flip settings.

        Args:
            txt_file: Path to a text file with one relative image path per line.
            data_root: Directory the relative image paths are joined onto.
            size: Side length of the square output image; ``None`` keeps the
                cropped image at its own size.
            interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``
                or ``"lanczos"``.
            flip_p: Probability of flipping each image horizontally.

        Raises:
            KeyError: If ``interpolation`` is not one of the supported names.
        """
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            # Skip blank lines: they would otherwise become entries whose
            # joined path is just ``data_root`` and fail when opened.
            self.image_paths = [line for line in f.read().splitlines() if line.strip()]
        self._length = len(self.image_paths)

        # Per-index metadata; __getitem__ copies the i-th value of every key.
        self.labels = {
            "relative_file_path_": list(self.image_paths),
            "file_path_": [os.path.join(self.data_root, rel_path)
                           for rel_path in self.image_paths],
        }

        self.size = size

        # ``Image.LINEAR`` was only an alias of ``Image.BILINEAR`` and was
        # removed in Pillow 10. Referencing it here would raise AttributeError
        # for every instance, whichever interpolation was requested.
        resample_modes = {
            "linear": Image.BILINEAR,
            "bilinear": Image.BILINEAR,
            "bicubic": Image.BICUBIC,
            "lanczos": Image.LANCZOS,
        }
        self.interpolation = resample_modes[interpolation]

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the number of images listed in the text file."""
        return self._length

    def __getitem__(self, i):
        """Load, crop, resize, flip and normalise the ``i``-th image.

        Returns:
            A dict with ``"relative_file_path_"``, ``"file_path_"`` and
            ``"image"``; the image is a ``float32`` array scaled to
            ``[-1, 1]``.
        """
        example = {key: values[i] for key, values in self.labels.items()}

        # Decode the pixels inside the ``with`` block so the underlying file
        # handle is released as soon as the data has been read.
        with Image.open(example["file_path_"]) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            img = np.array(rgb).astype(np.uint8)

        # default to score-sde preprocessing: centre-crop to the largest square
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map 0..255 onto -1..1.
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
70

NatalieC323's avatar
NatalieC323 committed
71
72
73
# A dataset class for the LSUN Churches training set.
# It calls the LSUNBase constructor with the path of the text file listing the
# images and the root directory where the images are stored. Any additional
# keyword arguments are forwarded to the constructor of the parent class.
74
75
76
77
class LSUNChurchesTrain(LSUNBase):
    """LSUN outdoor-church training subset.

    Fixes ``txt_file`` and ``data_root`` to the church training locations and
    forwards every other keyword argument to ``LSUNBase``.
    """

    def __init__(self, **kwargs):
        super().__init__(
            txt_file="data/lsun/church_outdoor_train.txt",
            data_root="data/lsun/churches",
            **kwargs,
        )

NatalieC323's avatar
NatalieC323 committed
78
79
# A dataset class for the LSUN Churches validation set.
# It is similar to LSUNChurchesTrain, except that it uses a different text file and sets the flip probability to zero by default.
80
81
82
83
84
class LSUNChurchesValidation(LSUNBase):
    """LSUN outdoor-church validation subset.

    Fixes ``txt_file`` and ``data_root`` to the church validation locations.
    ``flip_p`` defaults to zero here; it and every other keyword argument are
    forwarded to ``LSUNBase``.
    """

    def __init__(self, flip_p=0., **kwargs):
        super().__init__(
            txt_file="data/lsun/church_outdoor_val.txt",
            data_root="data/lsun/churches",
            flip_p=flip_p,
            **kwargs,
        )

NatalieC323's avatar
NatalieC323 committed
85
86
# A dataset class for LSUN Bedrooms training set.  
# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. 
87
88
89
90
class LSUNBedroomsTrain(LSUNBase):
    """LSUN bedroom training subset.

    Fixes ``txt_file`` and ``data_root`` to the bedroom training locations and
    forwards every other keyword argument to ``LSUNBase``.
    """

    def __init__(self, **kwargs):
        super().__init__(
            txt_file="data/lsun/bedrooms_train.txt",
            data_root="data/lsun/bedrooms",
            **kwargs,
        )

NatalieC323's avatar
NatalieC323 committed
91
92
# A dataset class for LSUN Bedrooms validation set. 
# It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default.
93
94
95
96
97
class LSUNBedroomsValidation(LSUNBase):
    """LSUN bedroom validation subset.

    Fixes ``txt_file`` and ``data_root`` to the bedroom validation locations.
    ``flip_p`` defaults to zero here; it and every other keyword argument are
    forwarded to ``LSUNBase``.
    """

    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(
            txt_file="data/lsun/bedrooms_val.txt",
            data_root="data/lsun/bedrooms",
            flip_p=flip_p,
            **kwargs,
        )

NatalieC323's avatar
NatalieC323 committed
98
99
100
# A dataset class for LSUN Cats training set. 
# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. 
# The text file containing the paths to the images and the root directory where the images are stored are passed as arguments.
101
102
103
104
class LSUNCatsTrain(LSUNBase):
    """LSUN cat training subset.

    Fixes ``txt_file`` and ``data_root`` to the cat training locations and
    forwards every other keyword argument to ``LSUNBase``.
    """

    def __init__(self, **kwargs):
        super().__init__(
            txt_file="data/lsun/cat_train.txt",
            data_root="data/lsun/cats",
            **kwargs,
        )

NatalieC323's avatar
NatalieC323 committed
105
106
# A dataset class for LSUN Cats validation set. 
# It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default.
107
108
109
110
class LSUNCatsValidation(LSUNBase):
    """LSUN cat validation subset.

    Fixes ``txt_file`` and ``data_root`` to the cat validation locations.
    ``flip_p`` defaults to zero here; it and every other keyword argument are
    forwarded to ``LSUNBase``.
    """

    def __init__(self, flip_p=0., **kwargs):
        super().__init__(
            txt_file="data/lsun/cat_val.txt",
            data_root="data/lsun/cats",
            flip_p=flip_p,
            **kwargs,
        )