utils.py 11.4 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def _get_noise_batch(noise,
                     image_shape,
                     num_timesteps=0,
                     num_batches=0,
                     timesteps_noise=False):
    """Get noise batch. Support get sequeue of noise along timesteps.

    We support the following use cases ('bz' denotes ```num_batches`` and 'n'
    denotes ``num_timesteps``):

    If timesteps_noise is True, we output noise which dimension is 5.
    - Input is [bz, c, h, w]: Expand to [n, bz, c, h, w]
    - Input is [n, c, h, w]: Expand to [n, bz, c, h, w]
    - Input is [n*bz, c, h, w]: View to [n, bz, c, h, w]
    - Dim of the input is 5: Return the input, ignore ``num_batches`` and
      ``num_timesteps``
    - Callable or None: Generate noise shape as [n, bz, c, h, w]
    - Otherwise: Raise error

    If timestep_noise is False, we output noise which dimension is 4 and
    ignore ``num_timesteps``.
    - Dim of the input is 3: Unsqueeze to [1, c, h, w], ignore ``num_batches``
    - Dim of the input is 4: Return input, ignore ``num_batches``
    - Callable or None: Generate noise shape as [bz, c, h, w]
    - Otherwise: Raise error

    It's to be noted that, we do not move the generated label to target device
    in this function because we can not get which device the noise should move
    to.

    Args:
        noise (torch.Tensor | callable | None): You can directly give a
            batch of noise through a ``torch.Tensor`` or offer a callable
            function to sample a batch of noise data. Otherwise, the
            ``None`` indicates to use the default noise sampler.
        image_shape (torch.Size): Size of images in the diffusion process.
        num_timesteps (int, optional): Total timestpes of the diffusion and
            denoising process. Defaults to 0.
        num_batches (int, optional): The number of batch size. To be noted that
            this argument only work when the input ``noise`` is callable or
            ``None``. Defaults to 0.
        timesteps_noise (bool, optional): If True, returned noise will shape
            as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
            Defaults to False.
        device (str, optional): If not ``None``, move the generated noise to
            corresponding device.
    Returns:
        torch.Tensor: Generated noise with desired shape.
    """
    if isinstance(noise, torch.Tensor):
        # conduct sanity check for the last three dimension
        assert noise.shape[-3:] == image_shape
        if timesteps_noise:
            if noise.ndim == 4:
                assert num_batches > 0 and num_timesteps > 0
                # noise shape as [n, c, h, w], expand to [n, bz, c, h, w]
                if noise.shape[0] == num_timesteps:
                    noise_batch = noise.view(num_timesteps, 1, *image_shape)
                    noise_batch = noise_batch.expand(-1, num_batches, -1, -1,
                                                     -1)
                # noise shape as [bz, c, h, w], expand to [n, bz, c, h, w]
                elif noise.shape[0] == num_batches:
                    noise_batch = noise.view(1, num_batches, *image_shape)
                    noise_batch = noise_batch.expand(num_timesteps, -1, -1, -1,
                                                     -1)
                # noise shape as [n*bz, c, h, w], reshape to [b, bz, c, h, w]
                elif noise.shape[0] == num_timesteps * num_batches:
                    noise_batch = noise.view(num_timesteps, -1, *image_shape)
                else:
                    raise ValueError(
                        'The timesteps noise should be in shape of '
                        '(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
                        f'(n, bz, c, h, w). But receive {noise.shape}.')

            elif noise.ndim == 5:
                # direct return noise
                noise_batch = noise
            else:
                raise ValueError(
                    'The timesteps noise should be in shape of '
                    '(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
                    f'(n, bz, c, h, w). But receive {noise.shape}.')
        else:
            if noise.ndim == 3:
                # reshape noise to [1, c, h, w]
                noise_batch = noise[None, ...]
            elif noise.ndim == 4:
                # do nothing
                noise_batch = noise
            else:
                raise ValueError(
                    'The noise should be in shape of (n, c, h, w) or'
                    f'(c, h, w), but got {noise.shape}')
    # receive a noise generator and sample noise.
    elif callable(noise):
        assert num_batches > 0
        noise_generator = noise
        if timesteps_noise:
            assert num_timesteps > 0
            # generate noise shape as [n, bz, c, h, w]
            noise_batch = noise_generator(
                (num_timesteps, num_batches, *image_shape))
        else:
            # generate noise shape as [bz, c, h, w]
            noise_batch = noise_generator((num_batches, *image_shape))
    # otherwise, we will adopt default noise sampler.
    else:
        assert num_batches > 0
        if timesteps_noise:
            assert num_timesteps > 0
            # generate noise shape as [n, bz, c, h, w]
            noise_batch = torch.randn(
                (num_timesteps, num_batches, *image_shape))
        else:
            # generate noise shape as [bz, c, h, w]
            noise_batch = torch.randn((num_batches, *image_shape))

    return noise_batch


def _get_label_batch(label,
                     num_timesteps=0,
                     num_classes=0,
                     num_batches=0,
                     timesteps_noise=False):
    """Get label batch. Support get sequeue of label along timesteps.

    We support the following use cases ('bz' denotes ```num_batches`` and 'n'
    denotes ``num_timesteps``):

    If num_classes <= 0, return None.

    If timesteps_noise is True, we output label which dimension is 2.
    - Input is [bz, ]: Expand to [n, bz]
    - Input is [n, ]: Expand to [n, bz]
    - Input is [n*bz, ]: View to [n, bz]
    - Dim of the input is 2: Return the input, ignore ``num_batches`` and
      ``num_timesteps``
    - Callable or None: Generate label shape as [n, bz]
    - Otherwise: Raise error

    If timesteps_noise is False, we output label which dimension is 1 and
    ignore ``num_timesteps``.
    - Dim of the input is 1: Unsqueeze to [1, ], ignore ``num_batches``
    - Dim of the input is 2: Return the input. ignore ``num_batches``
    - Callable or None: Generate label shape as [bz, ]
    - Otherwise: Raise error

    It's to be noted that, we do not move the generated label to target device
    in this function because we can not get which device the noise should move
    to.

    Args:
        label (torch.Tensor | callable | None): You can directly give a
            batch of noise through a ``torch.Tensor`` or offer a callable
            function to sample a batch of noise data. Otherwise, the
            ``None`` indicates to use the default noise sampler.
        num_timesteps (int, optional): Total timestpes of the diffusion and
            denoising process. Defaults to 0.
        num_batches (int, optional): The number of batch size. To be noted that
            this argument only work when the input ``noise`` is callable or
            ``None``. Defaults to 0.
        timesteps_noise (bool, optional): If True, returned noise will shape
            as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
            Defaults to False.
    Returns:
        torch.Tensor: Generated label with desired shape.
    """
    # no labels output if num_classes is 0
    if num_classes == 0:
        assert label is None, ('\'label\' should be None '
                               'if \'num_classes == 0\'.')
        return None

    # receive label and conduct sanity check.
    if isinstance(label, torch.Tensor):
        if timesteps_noise:
            if label.ndim == 1:
                assert num_batches > 0 and num_timesteps > 0
                # [n, ] to [n, bz]
                if label.shape[0] == num_timesteps:
                    label_batch = label.view(num_timesteps, 1)
                    label_batch = label_batch.expand(-1, num_batches)
                # [bz, ] to [n, bz]
                elif label.shape[0] == num_batches:
                    label_batch = label.view(1, num_batches)
                    label_batch = label_batch.expand(num_timesteps, -1)
                # [n*bz, ] to [n, bz]
                elif label.shape[0] == num_timesteps * num_batches:
                    label_batch = label.view(num_timesteps, -1)
                else:
                    raise ValueError(
                        'The timesteps label should be in shape of '
                        '(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
                        f'{label.shape}.')

            elif label.ndim == 2:
                # dimension is 2, direct return
                label_batch = label
            else:
                raise ValueError(
                    'The timesteps label should be in shape of '
                    '(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
                    f'{label.shape}.')
        else:
            # dimension is 0, expand to [1, ]
            if label.ndim == 0:
                label_batch = label[None, ...]
            # dimension is 1, do nothing
            elif label.ndim == 1:
                label_batch = label
            else:
                raise ValueError(
                    'The label should be in shape of (bz, ) or'
                    f'zero-dimension tensor, but got {label.shape}')
    # receive a noise generator and sample noise.
    elif callable(label):
        assert num_batches > 0
        label_generator = label
        if timesteps_noise:
            assert num_timesteps > 0
            # generate label shape as [n, bz]
            label_batch = label_generator((num_timesteps, num_batches))
        else:
            # generate label shape as [bz, ]
            label_batch = label_generator((num_batches, ))
    # otherwise, we will adopt default label sampler.
    else:
        assert num_batches > 0
        if timesteps_noise:
            assert num_timesteps > 0
            # generate label shape as [n, bz]
            label_batch = torch.randint(0, num_classes,
                                        (num_timesteps, num_batches))
        else:
            # generate label shape as [bz, ]
            label_batch = torch.randint(0, num_classes, (num_batches, ))

    return label_batch


def var_to_tensor(var, index, target_shape=None, device=None):
    """Extract variables by given index and convert them into a float tensor
    whose rank optionally matches a target shape.

    Args:
        var (np.ndarray): Variables to be extracted. It must be a numpy array
            because it is converted with ``torch.from_numpy``.
        index (torch.Tensor): Target index to extract. It is moved to cpu
            before indexing, so it may live on any device.
        target_shape (torch.Size, optional): If given, trailing dimensions of
            size 1 are appended to the indexed variable until its rank equals
            ``len(target_shape)``, so that it broadcasts against a tensor of
            that shape. The result is NOT expanded to ``target_shape``.
            Defaults to None, which leaves the rank unchanged.
        device (str | torch.device, optional): If given, the indexed variable
            will be moved to the target device. Otherwise, the indexed
            variable stays on cpu. Defaults to None.

    Returns:
        torch.Tensor: Converted variable with dtype ``torch.float32``.
    """
    # ``var`` is a numpy array (hence a cpu tensor after ``from_numpy``), so
    # index it on cpu and only move the indexed result to ``device``.
    var_indexed = torch.from_numpy(var)[index.cpu()].float()

    if device is not None:
        var_indexed = var_indexed.to(device)

    # ``target_shape`` is optional; only pad the rank when it is given
    # (``len(None)`` would raise a TypeError).
    if target_shape is not None:
        while len(var_indexed.shape) < len(target_shape):
            var_indexed = var_indexed[..., None]
    return var_indexed