"nndet/arch/vscode:/vscode.git/clone" did not exist on "d07c53230fd383134309ed72ae5659e5f996a7e4"
metric_utils.py 9.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# Copyright (c) OpenMMLab. All rights reserved.
import sys

import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.parallel import is_module_wrapper

from mmgen.models.architectures.common import get_module_device


@torch.no_grad()
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        # a dirty walkround to support multiple datasets (mainly for the
        # unconditional dataset and conditional dataset). In our
        # implementation, unconditioanl dataset will return real images with
        # the key "real_img". However, the conditional dataset contains a key
        # "img" denoting the real images.
        if 'real_img' in data:
            # Mainly for the unconditional dataset in our MMGeneration
            img = data['real_img']
        else:
            # Mainly for conditional dataset in MMClassification
            img = data['img']
        pbar.update()

        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Attention: the number of features may be different as you want.
    features = torch.cat(feature_list, 0)

    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # to change the line after pbar
    sys.stdout.write('\n')
    return features


def _hox_downsample(img):
    r"""Downsample images with factor equal to 0.5.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py  # noqa

    Args:
        img (ndarray): Images with order "NHWC".

    Returns:
        ndarray: Downsampled images with order "NHWC".
    """
    return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] +
            img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25


def _f_special_gauss(size, sigma):
    r"""Return a circular symmetric gaussian kernel.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py  # noqa

    Args:
        size (int): Size of Gaussian kernel.
        sigma (float): Standard deviation for Gaussian blur kernel.

    Returns:
        ndarray: Gaussian kernel.
    """
    radius = size // 2
    offset = 0.0
    start, stop = -radius, radius + 1
    if size % 2 == 0:
        offset = 0.5
        stop -= 1
    x, y = np.mgrid[offset + start:stop, offset + start:stop]
    assert len(x) == size
    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
    return g / g.sum()


# Gaussian blur kernel
def get_gaussian_kernel():
    kernel = np.array([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6],
                       [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]],
                      np.float32) / 256.0
    gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5))
    return gaussian_k


def get_pyramid_layer(image, gaussian_k, direction='down'):
    gaussian_k = gaussian_k.to(image.device)
    if direction == 'up':
        image = F.interpolate(image, scale_factor=2)
    multiband = [
        F.conv2d(
            image[:, i:i + 1, :, :],
            gaussian_k,
            padding=2,
            stride=1 if direction == 'up' else 2) for i in range(3)
    ]
    image = torch.cat(multiband, dim=1)
    return image


def gaussian_pyramid(original, n_pyramids, gaussian_k):
    x = original
    # pyramid down
    pyramids = [original]
    for _ in range(n_pyramids):
        x = get_pyramid_layer(x, gaussian_k)
        pyramids.append(x)
    return pyramids


def laplacian_pyramid(original, n_pyramids, gaussian_k):
    """Calculate Laplacian pyramid.

    Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py

    Args:
        original (Tensor): Batch of Images with range [0, 1] and order "NCHW"
        n_pyramids (int): Levels of pyramids minus one.
        gaussian_k (Tensor): Gaussian kernel with shape (1, 1, 5, 5).

    Return:
        list[Tensor]. Laplacian pyramids of original.
    """
    # create gaussian pyramid
    pyramids = gaussian_pyramid(original, n_pyramids, gaussian_k)

    # pyramid up - diff
    laplacian = []
    for i in range(len(pyramids) - 1):
        diff = pyramids[i] - get_pyramid_layer(pyramids[i + 1], gaussian_k,
                                               'up')
        laplacian.append(diff)
    # Add last gaussian pyramid
    laplacian.append(pyramids[len(pyramids) - 1])
    return laplacian


def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
    r"""Get descriptors of one level of pyramids.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py  # noqa

    Args:
        minibatch (Tensor): Pyramids of one level with order "NCHW".
        nhood_size (int): Pixel neighborhood size.
        nhoods_per_image (int): The number of descriptors per image.

    Return:
        Tensor: Descriptors of images from one level batch.
    """
    S = minibatch.shape  # (minibatch, channel, height, width)
    assert len(S) == 4 and S[1] == 3
    N = nhoods_per_image * S[0]
    H = nhood_size // 2
    nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H + 1, -H:H + 1]
    img = nhood // nhoods_per_image
    x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1))
    y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1))
    idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x
    return minibatch.view(-1)[idx]


def finalize_descriptors(desc):
    r"""Normalize and reshape descriptors.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py  # noqa

    Args:
        desc (list or Tensor): List of descriptors of one level.

    Return:
        Tensor: Descriptors after normalized along channel and flattened.
    """
    if isinstance(desc, list):
        desc = torch.cat(desc, dim=0)
    assert desc.ndim == 4  # (neighborhood, channel, height, width)
    desc -= torch.mean(desc, dim=(0, 2, 3), keepdim=True)
    desc /= torch.std(desc, dim=(0, 2, 3), keepdim=True)
    desc = desc.reshape(desc.shape[0], -1)
    return desc


def compute_pr_distances(row_features,
                         col_features,
                         num_gpus,
                         rank,
                         col_batch_size=10000):
    r"""Compute distances between real images and fake images.

    This function is used for calculate Precision and Recall metric.
    Refer to:https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py  # noqa
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    col_batches = torch.nn.functional.pad(
        col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
    dist_batches = []
    for col_batch in col_batches[rank::num_gpus]:
        dist_batch = torch.cdist(
            row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            dist_broadcast = dist_batch.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(dist_broadcast, src=src)
            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None


def normalize(a):
    """L2 normalization.

    Args:
        a (Tensor): Tensor with shape [N, C].

    Returns:
        Tensor: Tensor after L2 normalization per-instance.
    """
    return a / torch.norm(a, dim=1, keepdim=True)


def slerp(a, b, percent):
    """Spherical linear interpolation between two unnormalized vectors.

    Args:
        a (Tensor): Tensor with shape [N, C].
        b (Tensor): Tensor with shape [N, C].
        percent (float|Tensor): A float or tensor with shape broadcastable to
            the shape of input Tensors.

    Returns:
        Tensor: Spherical linear interpolation result with shape [N, C].
    """
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)
    p = percent * torch.acos(d)
    c = normalize(b - d * a)
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)