# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
import sys
from copy import deepcopy

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from prettytable import PrettyTable
from torchvision.utils import save_image

from mmgen.datasets import build_dataloader, build_dataset


def make_metrics_table(train_cfg, ckpt, eval_info, metrics):
    """Arrange evaluation results into a table.

    Args:
        train_cfg (str): Name of the training configuration.
        ckpt (str): Path of the evaluated model's weights.
        eval_info (str): Information shown in the 'Eval' column, e.g. the
            sample model used for evaluation.
        metrics (list): List of metric objects.

    Returns:
        str: String of the eval table.
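
    Example:
        A minimal sketch; the config and checkpoint paths are illustrative
        only, and each metric object exposes ``name`` and ``result_str``::

            >>> table_str = make_metrics_table(
            ...     'configs/my_gan_config.py',     # hypothetical
            ...     'work_dirs/my_gan/latest.pth',  # hypothetical
            ...     'ema', metrics)
            >>> print(table_str)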
    """
    table = PrettyTable()
    table.set_style(14)
    table.add_column('Training configuration', [train_cfg])
    table.add_column('Checkpoint', [ckpt])
    table.add_column('Eval', [eval_info])
    for metric in metrics:
        table.add_column(metric.name, [metric.result_str])
    return table.get_string()


def make_vanilla_dataloader(img_path, batch_size, dist=False):
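    """Build a dataloader that reads images from a folder.

    The loader yields dicts whose ``real_img`` entry is an image tensor
    normalized to roughly [-1, 1], which is the format the metric objects
    expect.

    Args:
        img_path (str): Root directory of the images.
        batch_size (int): Number of samples per GPU.
        dist (bool, optional): Whether to build a distributed dataloader.
            Defaults to False.

    Returns:
        DataLoader: Dataloader over the image folder.
    """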
    pipeline = [
        dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
        dict(
            type='Normalize',
            keys=['real_img'],
            mean=[127.5] * 3,
            std=[127.5] * 3,
            to_rgb=False),
        dict(type='ImageToTensor', keys=['real_img']),
        dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
    ]
    dataset = build_dataset(
        dict(
            type='UnconditionalImageDataset',
            imgs_root=img_path,
            pipeline=pipeline,
        ))
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=batch_size,
        workers_per_gpu=4,
        dist=dist,
        shuffle=True)
    return dataloader


@torch.no_grad()
def offline_evaluation(model,
                       data_loader,
                       metrics,
                       logger,
                       basic_table_info,
                       batch_size,
                       samples_path=None,
                       **kwargs):
    """Evaluate model in offline mode.

    This method first saves the generated images to disk and then loads
    them back with a dataloader to compute the metrics.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader for real images.
        metrics (list): List of metric objects.
        logger (Logger): Logger used to record evaluation results.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            checkpoint.
        batch_size (int): Batch size of images fed into metrics.
        samples_path (str, optional): Directory used to save the generated
            images. If None, a temporary directory will be created and
            deleted after the evaluation finishes. Defaults to None.
        kwargs (dict): Other arguments.
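
    Example:
        A minimal usage sketch; the config path, checkpoint path and
        sample count below are illustrative only::

            >>> basic_table_info = dict(
            ...     train_cfg='configs/my_gan_config.py',  # hypothetical
            ...     ckpt='work_dirs/my_gan/latest.pth',    # hypothetical
            ...     sample_model='ema',
            ...     num_samples=-1)
            >>> offline_evaluation(
            ...     model, data_loader, metrics, logger,
            ...     basic_table_info, batch_size=10,
            ...     samples_path='work_dirs/eval_samples')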
    """
    # special and reconstruction metrics can only be evaluated online
    online_metric_name = ['PPL', 'GaussianKLD']
    for metric in metrics:
        assert metric.name not in online_metric_name, 'Please eval '\
             f'{metric.name} online'

    rank, ws = get_dist_info()

    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0 and rank == 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)

    # if enough images already exist, `num_needed` is zero and no sampling runs
    total_batch_size = batch_size * ws
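    # every rank samples `batch_size` images per iteration, so
    # `total_batch_size` images are generated across all ranks per step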
    for begin in range(0, num_needed, total_batch_size):
        end = min(begin + batch_size, max_num_images)
        fakes = model(
            None,
            num_batches=end - begin,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)
        global_end = min(begin + total_batch_size, max_num_images)
        if rank == 0:
            pbar.update(global_end - begin)

        # gather generated images
        if ws > 1:
            placeholder = [torch.zeros_like(fakes) for _ in range(ws)]
            dist.all_gather(placeholder, fakes)
            fakes = torch.cat(placeholder, dim=0)

        # reorder channels (BGR -> RGB) and expand grayscale images to three
        # channels before saving
        if fakes.size(1) == 3:
            fakes = fakes[:, [2, 1, 0], ...]
        elif fakes.size(1) == 1:
            fakes = torch.cat([fakes] * 3, dim=1)
        else:
            raise RuntimeError('Generated images must have one or three '
                               'channels in the channel dimension (dim 1), '
                               'not %d' % fakes.size(1))

        if rank == 0:
            for i in range(global_end - begin):
                images = fakes[i:i + 1]
                images = ((images + 1) / 2)
                images = images.clamp_(0, 1)
                image_name = str(num_exist + begin + i) + '.png'
                save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0 and rank == 0:
        sys.stdout.write('\n')

    # return early if we only need to save the sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(
        samples_path, batch_size, dist=ws > 1)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        if rank == 0:
            # prepare for pbar
            total_need = (
                metric.num_real_need + metric.num_fake_need -
                metric.num_real_feeded - metric.num_fake_feeded)
            pbar = mmcv.ProgressBar(total_need)
        # feed in real images
        for data in data_loader:
            # key for unconditional GAN
            if 'real_img' in data:
                reals = data['real_img']
            # key for conditional GAN
            elif 'img' in data:
                reals = data['img']
            else:
                raise KeyError('Cannot find a key for images in data_dict. '
                               'Only `real_img` is supported for '
                               'unconditional datasets and `img` for '
                               'conditional datasets.')

            if reals.shape[1] == 1:
                reals = torch.cat([reals] * 3, dim=1)
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(reals.shape[0] * ws)
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            if fakes.shape[1] == 1:
                fakes = torch.cat([fakes] * 3, dim=1)
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
            if rank == 0:
                pbar.update(fakes.shape[0] * ws)
        if rank == 0:
            # only call summary at main device
            metric.summary()
            sys.stdout.write('\n')
    if rank == 0:
        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)
        if delete_samples_path:
            shutil.rmtree(samples_path)


@torch.no_grad()
def online_evaluation(model, data_loader, metrics, logger, basic_table_info,
                      batch_size, **kwargs):
    """Evaluate model in online mode.

    This method evaluates the model and displays an evaluation progress
    bar. Different from `offline_evaluation`, this function neither saves
    generated images to disk nor reads them back, so it performs no image
    I/O. Online mode is therefore generally faster, but it costs much more
    memory. Note that distributed evaluation is currently supported for FID
    and IS only.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader for real images.
        metrics (list): List of metric objects.
        logger (Logger): Logger used to record evaluation results.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including the training configuration and
            checkpoint.
        batch_size (int): Batch size of images fed into metrics.
        kwargs (dict): Other arguments.
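
    Example:
        A minimal usage sketch; the config and checkpoint paths below are
        illustrative only::

            >>> basic_table_info = dict(
            ...     train_cfg='configs/my_gan_config.py',  # hypothetical
            ...     ckpt='work_dirs/my_gan/latest.pth',    # hypothetical
            ...     sample_model='ema')
            >>> online_evaluation(
            ...     model, data_loader, metrics, logger,
            ...     basic_table_info, batch_size=10)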
    """
    # Separate metrics into special metrics, reconstruction metrics and
    # vanilla metrics.
    # For vanilla metrics, images are generated randomly and shared by all
    # of them. For special metrics like 'PPL', images are generated in a
    # metric-specific way and are not shared between different metrics.
    # Reconstruction metrics like 'GaussianKLD' do not receive images;
    # instead, they receive a dict with the corresponding probabilistic
    # parameters.

    rank, ws = get_dist_info()

    special_metrics = []
    recon_metrics = []
    vanilla_metrics = []
    special_metric_name = ['PPL']
    recon_metric_name = ['GaussianKLD']
    for metric in metrics:
        if ws > 1:
            assert metric.name in [
                'FID', 'IS'
            ], ('We only support FID and IS for distributed evaluation '
                f'currently, but receive {metric.name}')

        if metric.name in special_metric_name:
            special_metrics.append(metric)
        elif metric.name in recon_metric_name:
            recon_metrics.append(metric)
        else:
            vanilla_metrics.append(metric)

    # define mmcv progress bar
    max_num_images = 0
    for metric in vanilla_metrics + recon_metrics:
        metric.prepare()
        max_num_images = max(max_num_images,
                             metric.num_real_need - metric.num_real_feeded)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} real images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)

    # guard against `data_loader` being None
    data_loader = [] if data_loader is None else data_loader
    for data in data_loader:
        if 'real_img' in data:
            reals = data['real_img']
        # key for conditional GAN
        elif 'img' in data:
            reals = data['img']
        else:
            raise KeyError('Cannot find a key for images in data_dict. '
                           'Only `real_img` is supported for unconditional '
                           'datasets and `img` for conditional '
                           'datasets.')

        if reals.shape[1] not in [1, 3]:
            raise RuntimeError('real images should have one or three '
                               'channels in the channel dimension (dim 1), '
                               'not %d' % reals.shape[1])
        if reals.shape[1] == 1:
            reals = reals.repeat(1, 3, 1, 1)

        num_feed = 0
        for metric in vanilla_metrics:
            num_feed_ = metric.feed(reals, 'reals')
            num_feed = max(num_feed_, num_feed)
        for metric in recon_metrics:
            kwargs_ = deepcopy(kwargs)
            kwargs_['mode'] = 'reconstruction'
            prob_dict = model(reals, return_loss=False, **kwargs_)
            num_feed_ = metric.feed(prob_dict, 'reals')
            num_feed = max(num_feed_, num_feed)

        if num_feed <= 0:
            break

        if rank == 0:
            pbar.update(num_feed)

    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')

    # define mmcv progress bar
    max_num_images = 0 if len(vanilla_metrics) == 0 else max(
        metric.num_fake_need for metric in vanilla_metrics)
    if rank == 0:
        mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                       'mmgen')
        pbar = mmcv.ProgressBar(max_num_images)
    # sampling fake images and directly send them to metrics
    total_batch_size = batch_size * ws
    for _ in range(0, max_num_images, total_batch_size):
        fakes = model(
            None,
            num_batches=batch_size,
            return_loss=False,
            sample_model=basic_table_info['sample_model'],
            **kwargs)

        if fakes.shape[1] not in [1, 3]:
            raise RuntimeError('fake images should have one or three '
                               'channels in the channel dimension (dim 1), '
                               'not %d' % fakes.shape[1])
        if fakes.shape[1] == 1:
            fakes = torch.cat([fakes] * 3, dim=1)

        for metric in vanilla_metrics:
            # feed in fake images
            metric.feed(fakes, 'fakes')

        if rank == 0:
            pbar.update(total_batch_size)

    if rank == 0:
        # finish the pbar stdout
        sys.stdout.write('\n')

    # feed special metrics; distributed evaluation is not considered here
    for metric in special_metrics:
        metric.prepare()
        fakedata_iterator = iter(
            metric.get_sampler(model.module, batch_size,
                               basic_table_info['sample_model']))
        mmcv.print_log(
            f'Sample {metric.num_images} samples for evaluating {metric.name}',
            'mmgen')
        pbar = mmcv.ProgressBar(metric.num_images)
        for fakes in fakedata_iterator:
            num_left = metric.feed(fakes, 'fakes')
            pbar.update(fakes.shape[0])
            if num_left <= 0:
                break

        # finish the pbar stdout
        sys.stdout.write('\n')

    if rank == 0:
        for metric in metrics:
            metric.summary()

        table_str = make_metrics_table(basic_table_info['train_cfg'],
                                       basic_table_info['ckpt'],
                                       basic_table_info['sample_model'],
                                       metrics)
        logger.info('\n' + table_str)