singan_dataset.py 4.05 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset

from .builder import DATASETS


def create_real_pyramid(real, min_size, max_size, scale_factor_init):
    """Create image pyramid.

    This function is modified from the official implementation:
    https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221

    In this implementation, we adopt the rescaling function from MMCV.

    The image is first rescaled by ``min(max_size / max(H, W), 1)`` so that
    it is never enlarged. The pyramid is then built from that size-capped
    image, so the finest level equals the capped image and the shorter side
    of the coarsest level is approximately ``min_size``.

    Args:
        real (np.array): The real image array. Only the first two entries of
            ``real.shape`` (height and width) are inspected here.
        min_size (int): The minimum size for the image pyramid.
        max_size (int): The maximum size for the image pyramid.
        scale_factor_init (float): The initial scale factor. Must lie
            strictly between 0 and 1.

    Returns:
        tuple: ``(reals, scale_factor, stop_scale)`` where ``reals`` is the
        list of ``stop_scale + 1`` rescaled images ordered from coarsest to
        finest, ``scale_factor`` is the adjusted per-level rescale factor
        and ``stop_scale`` is the index of the finest level.

    Raises:
        ValueError: If ``min_size`` or ``max_size`` is not positive, if
            ``scale_factor_init`` is outside ``(0, 1)``, or if the sizes
            leave fewer than two pyramid levels.
    """
    if min_size <= 0 or max_size <= 0:
        raise ValueError(
            'min_size and max_size must be positive, but got '
            f'min_size={min_size} and max_size={max_size}.')
    if not 0 < scale_factor_init < 1:
        raise ValueError(
            'scale_factor_init must be in the open interval (0, 1), '
            f'but got {scale_factor_init}.')

    height, width = real.shape[:2]
    short_side = min(height, width)
    long_side = max(height, width)
    log_init = np.log(scale_factor_init)

    # Number of levels needed to go from the short side down to ``min_size``
    # when shrinking by ``scale_factor_init`` at every level.
    num_scales = int(np.ceil(np.log(min_size / short_side) / log_init)) + 1

    # Number of finest levels dropped because the image exceeds ``max_size``.
    scale2stop = int(
        np.ceil(np.log(min(max_size, long_side) / long_side) / log_init))

    stop_scale = num_scales - scale2stop
    if stop_scale < 1:
        # ``stop_scale == 0`` would divide by zero below and a negative value
        # would yield an empty pyramid, so fail with a clear message instead.
        raise ValueError(
            'Cannot build an image pyramid with at least two levels for an '
            f'image of size ({height}, {width}) with min_size={min_size}, '
            f'max_size={max_size} and '
            f'scale_factor_init={scale_factor_init}.')

    # Cap the image at ``max_size`` (never upscale).
    scale1 = min(max_size / long_side, 1)
    real_max = mmcv.imrescale(real, scale1)

    # Re-derive the factor so that exactly ``stop_scale`` steps lead from the
    # capped image down to ``min_size``.
    scale_factor = np.power(min_size / min(real_max.shape[:2]),
                            1 / stop_scale)

    # Build the pyramid from the capped image, as the official implementation
    # does, so that ``min_size`` and ``max_size`` are actually honoured.
    reals = [
        mmcv.imrescale(real_max, np.power(scale_factor, stop_scale - i))
        for i in range(stop_scale + 1)
    ]

    return reals, scale_factor, stop_scale


@DATASETS.register_module()
class SinGANDataset(Dataset):
    """Dataset serving the cached image pyramid of a single image.

    The image is read once at construction time, turned into a multi-scale
    pyramid with ``create_real_pyramid`` and every level is stored as a
    tensor in ``self.data_dict``. Each call to ``__getitem__`` returns that
    same cached dictionary regardless of the index.

    Args:
        img_path (str): Path to the single image file.
        min_size (int): Min size of the image pyramid, forwarded to
            ``create_real_pyramid``.
        max_size (int): Max size of the image pyramid, forwarded to
            ``create_real_pyramid``.
        scale_factor_init (float): Initial rescale factor. The factor that
            is actually used is the one returned by
            ``create_real_pyramid`` and may differ slightly from this value.
        num_samples (int, optional): The number of samples (length) in this
            dataset. A negative value makes the reported length
            ``int(1e6)``. Defaults to -1.
    """

    def __init__(self,
                 img_path,
                 min_size,
                 max_size,
                 scale_factor_init,
                 num_samples=-1):
        self.img_path = img_path
        assert mmcv.is_filepath(self.img_path)
        # The pyramid is built eagerly so later indexing is a plain lookup.
        self.load_annotations(min_size, max_size, scale_factor_init)
        self.num_samples = num_samples

    def load_annotations(self, min_size, max_size, scale_factor_init):
        """Load annotations for SinGAN Dataset.

        Reads the image at ``self.img_path``, builds its pyramid and fills
        ``self.reals``, ``self.scale_factor``, ``self.stop_scale`` and
        ``self.data_dict``.

        Args:
            min_size (int): The minimum size for the image pyramid.
            max_size (int): The maximum size for the image pyramid.
            scale_factor_init (float): The initial scale factor.
        """
        image = mmcv.imread(self.img_path)
        pyramid = create_real_pyramid(image, min_size, max_size,
                                      scale_factor_init)
        self.reals, self.scale_factor, self.stop_scale = pyramid

        # One tensor per pyramid level, keyed by its level index.
        self.data_dict = {}
        self.data_dict.update(
            (f'real_scale{level}', self._img2tensor(level_img))
            for level, level_img in enumerate(self.reals))

        # All-zero tensor shaped like the first pyramid level.
        first_level = self.data_dict['real_scale0']
        self.data_dict['input_sample'] = torch.zeros_like(first_level)

    def _img2tensor(self, img):
        """Convert an image array into a normalized float tensor.

        The last axis of the 3-D input is moved to the front and values are
        mapped with ``(x / 255 - 0.5) * 2``, which yields ``[-1, 1]`` when
        the input holds 8-bit pixel values (presumably the case for images
        read by ``mmcv.imread`` -- confirm against the MMCV docs).

        Args:
            img (np.array): Image array with three axes.

        Returns:
            torch.Tensor: Contiguous ``float32`` tensor.
        """
        tensor = torch.from_numpy(img)
        tensor = tensor.to(torch.float32)
        tensor = tensor.permute(2, 0, 1).contiguous()
        tensor = tensor / 255
        tensor = tensor - 0.5
        tensor = tensor * 2

        return tensor

    def __getitem__(self, index):
        # Every index maps to the same cached pyramid.
        return self.data_dict

    def __len__(self):
        if self.num_samples < 0:
            return int(1e6)
        return self.num_samples