metrics.py 57 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
# Copyright (c) OpenMMLab. All rights reserved.
import os
import pickle
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial

import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmcv.runner import get_dist_info
from scipy import linalg, signal
from scipy.stats import entropy
from torchvision import models
from torchvision.models.inception import inception_v3

from mmgen.models.architectures import InceptionV3
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.lpips import PerceptualLoss
from mmgen.models.losses import gaussian_kld
from mmgen.utils import MMGEN_CACHE_DIR
from mmgen.utils.io_utils import download_from_url
from ..registry import METRICS
from .metric_utils import (_f_special_gauss, _hox_downsample,
                           compute_pr_distances, finalize_descriptors,
                           get_descriptors_for_minibatch, get_gaussian_kernel,
                           laplacian_pyramid, slerp)

# Default location of Tero's script Inception checkpoint. It is used when
# ``inception_args`` carries no 'inception_path' and as the fallback URL of
# ``_load_inception_from_url``.
TERO_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'  # noqa


def load_inception(inception_args, metric):
    """Build the Inception network described by ``inception_args``.

    The optional ``'type'`` entry of ``inception_args`` picks the flavour of
    the network:

    - Any value other than ``'StyleGAN'`` (a missing entry included) loads
      the PyTorch implementation through ``_load_inception_torch``.
    - ``'StyleGAN'`` loads Tero's script model. The location is taken from
      ``'inception_path'`` (``TERO_INCEPTION_URL`` when absent). A location
      that does not contain ``'http'`` is first loaded from disk; if that
      does not yield a module, or the location contains ``'http'``, the
      location is handed to the downloader.

    When the installed PyTorch version compares lower than ``'1.6.0'`` the
    PyTorch implementation is used regardless of ``'type'``.

    Args:
        inception_args (dict): Keyword args for inception net.
        metric (string): Metric to use the Inception. This argument would
            influence the pytorch's Inception loading.

    Returns:
        model (torch.nn.Module): Loaded Inception model.
        style (string): The version of the loaded Inception, ``'pytorch'``
            or ``'StyleGAN'``.

    Raises:
        TypeError: If ``inception_args`` is not a dict.
        RuntimeError: If the ``'StyleGAN'`` model can be loaded neither from
            disk nor by downloading.
    """
    if not isinstance(inception_args, dict):
        raise TypeError('Receive invalid \'inception_args\': '
                        f'\'{inception_args}\'')

    # Work on a copy so that popping 'type' never touches the caller's dict.
    torch_kwargs = deepcopy(inception_args)
    requested_type = torch_kwargs.pop('type', None)

    if torch.__version__ < '1.6.0':
        fallback_notice = (
            'Current Pytorch Version not support script module, load '
            'Inception Model from torch model zoo. If you want to use '
            'Tero\' script model, please update your Pytorch higher '
            f'than \'1.6\' (now is {torch.__version__})')
        mmcv.print_log(fallback_notice, 'mmgen')
        return _load_inception_torch(torch_kwargs, metric), 'pytorch'

    # Everything except an explicit 'StyleGAN' request uses the PyTorch model.
    if requested_type != 'StyleGAN':
        return _load_inception_torch(torch_kwargs, metric), 'pytorch'

    # Tero's version: a local path is tried on disk first, the downloader is
    # always the last resort.
    location = torch_kwargs.get('inception_path', TERO_INCEPTION_URL)
    loaders = []
    if 'http' not in location:
        loaders.append(_load_inception_from_path)
    loaders.append(_load_inception_from_url)

    for loader in loaders:
        candidate = loader(location)
        if isinstance(candidate, torch.nn.Module):
            return candidate, 'StyleGAN'

    raise RuntimeError('Cannot Load Inception Model, please check the input '
                       f'`inception_args`: {inception_args}')


def _load_inception_from_path(inception_path):
    """Load Tero's script Inception model from ``inception_path``.

    Args:
        inception_path (str): Location handed to ``torch.jit.load``.

    Returns:
        The object returned by ``torch.jit.load``, or ``None`` when any
        exception is raised while loading (the exception is logged, not
        propagated).
    """
    mmcv.print_log(
        'Try to load Tero\'s Inception Model from '
        f'\'{inception_path}\'.', 'mmgen')
    try:
        script_model = torch.jit.load(inception_path)
        mmcv.print_log('Load Tero\'s Inception Model successfully.', 'mmgen')
        return script_model
    except Exception as err:
        # Best effort: report the failure and let the caller decide.
        mmcv.print_log(
            'Load Tero\'s Inception Model failed. '
            f'\'{err}\' occurs.', 'mmgen')
        return None


def _load_inception_from_url(inception_url):
    """Download an Inception checkpoint and load it as a script model.

    Args:
        inception_url (str): URL of the checkpoint. A falsy value falls back
            to ``TERO_INCEPTION_URL``.

    Returns:
        The result of ``_load_inception_from_path`` on the downloaded file,
        or ``None`` when an exception is raised along the way (the exception
        is logged, not propagated).
    """
    inception_url = inception_url if inception_url else TERO_INCEPTION_URL
    mmcv.print_log(f'Try to download Inception Model from {inception_url}...',
                   'mmgen')
    try:
        path = download_from_url(inception_url, dest_dir=MMGEN_CACHE_DIR)
        # Use the 'mmgen' logger like every other message in this module, so
        # the download status is not emitted through a bare ``print``.
        mmcv.print_log('Download Finished.', 'mmgen')
        return _load_inception_from_path(path)
    except Exception as e:
        mmcv.print_log(f'Download Failed. {e} occurs.', 'mmgen')
        return None


def _load_inception_torch(inception_args, metric):
    """Create the PyTorch flavour of Inception for the given metric.

    Args:
        inception_args (dict): Keyword args forwarded to ``InceptionV3``
            when ``metric`` is ``'FID'``; unused for ``'IS'``.
        metric (str): Either ``'FID'`` or ``'IS'``.

    Returns:
        The constructed Inception model.
    """
    assert metric in ['FID', 'IS']
    if metric == 'IS':
        inception_model = inception_v3(pretrained=True, transform_input=False)
        accuracy_notice = (
            'Load Inception V3 Network from Pytorch Model Zoo '
            'for IS calculation. The results can only used '
            'for monitoring purposes. To get more accuracy IS, '
            'please use Tero\'s Inception V3 checkpoints '
            'and use Bicubic Interpolation with Pillow backend '
            'for image resizing. More details may refer to '
            'https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/quick_run.md#is.'  # noqa
        )
        mmcv.print_log(accuracy_notice, 'mmgen')
    elif metric == 'FID':
        inception_model = InceptionV3([3], **inception_args)
    return inception_model


def _ssim_for_multi_scale(img1,
                          img2,
                          max_val=255,
                          filter_size=11,
                          filter_sigma=1.5,
                          k1=0.01,
                          k2=0.03):
    """Compute per-image SSIM and contrast sensitivity for one scale.

    Ref:
    Image quality assessment: From error visibility to structural similarity.

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    The statistics are averaged over height, width and channels, so one value
    per image of the batch is returned.

    This function attempts to match the functionality of ssim_index_new.m by
    Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): The dynamic range of the images (i.e., the difference
            between the maximum and the minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). A falsy value disables blurring. Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel
            (will be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM
            calculation (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM
            calculation (0.03 in the original paper). Default to 0.03.

    Returns:
        tuple: Pair containing the mean SSIM and contrast sensitivity between
        `img1` and `img2`, each with one entry per image.

    Raises:
        RuntimeError: If the inputs differ in shape or are not 4-dimensional.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)
    _, height, width, _ = img1.shape

    # The kernel may not exceed the image; sigma shrinks proportionally.
    size = min(filter_size, height, width)
    sigma = size * filter_sigma / filter_size if filter_size else 0

    if filter_size:
        kernel = np.reshape(_f_special_gauss(size, sigma), (1, size, size, 1))

        def smooth(arr):
            return signal.fftconvolve(arr, kernel, mode='valid')
    else:
        # Empty blur kernel so no need to convolve.
        def smooth(arr):
            return arr

    # Local means and their products.
    mu1, mu2 = smooth(img1), smooth(img2)
    mu11, mu22, mu12 = mu1 * mu1, mu2 * mu2, mu1 * mu2

    # Local (co)variances: E[xy] - E[x]E[y].
    sigma11 = smooth(img1 * img1)
    sigma11 -= mu11
    sigma22 = smooth(img2 * img2)
    sigma22 -= mu22
    sigma12 = smooth(img1 * img2)
    sigma12 -= mu12

    # Stabilising constants and the terms shared by ssim and cs.
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    v1 = 2.0 * sigma12 + c2
    v2 = sigma11 + sigma22 + c2

    numerator = (2.0 * mu12 + c1) * v1
    denominator = (mu11 + mu22 + c1) * v2
    per_image_axes = (1, 2, 3)  # keep the batch axis
    ssim = np.mean(numerator / denominator, axis=per_image_axes)
    cs = np.mean(v1 / v2, axis=per_image_axes)
    return ssim, cs


def ms_ssim(img1,
            img2,
            max_val=255,
            filter_size=11,
            filter_sigma=1.5,
            k1=0.01,
            k2=0.03,
            weights=None):
    """Calculate MS-SSIM (multi-scale structural similarity).

    Ref:
    This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
    Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
    similarity for image quality assessment" (2003).
    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf

    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    PGGAN's implementation:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py

    Args:
        img1 (ndarray): Images with range [0, 255] and order "NHWC".
        img2 (ndarray): Images with range [0, 255] and order "NHWC".
        max_val (int): The dynamic range of the images (i.e., the difference
            between the maximum and the minimum allowed values).
            Default to 255.
        filter_size (int): Size of blur kernel to use (will be reduced for
            small images). Default to 11.
        filter_sigma (float): Standard deviation for Gaussian blur kernel
            (will be reduced for small images). Default to 1.5.
        k1 (float): Constant used to maintain stability in the SSIM
            calculation (0.01 in the original paper). Default to 0.01.
        k2 (float): Constant used to maintain stability in the SSIM
            calculation (0.03 in the original paper). Default to 0.03.
        weights (list | tuple | ndarray, optional): Weights for each level;
            if ``None`` or empty, use five levels and the weights from the
            original paper. Default to None.

    Returns:
        float: MS-SSIM score between `img1` and `img2`, averaged over the
        images of the batch.

    Raises:
        RuntimeError: If the inputs differ in shape or are not 4-dimensional.
    """
    if img1.shape != img2.shape:
        raise RuntimeError(
            'Input images must have the same shape (%s vs. %s).' %
            (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' %
                           img1.ndim)

    # Note: default weights don't sum to 1.0 but do match the paper / matlab
    # code.
    # An explicit None/empty check is used instead of plain truthiness, which
    # raises for multi-element ndarrays ("truth value ... is ambiguous").
    if weights is None or len(weights) == 0:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = np.array(weights)
    levels = weights.size
    im1, im2 = [x.astype(np.float32) for x in [img1, img2]]
    mssim = []
    mcs = []
    for _ in range(levels):
        ssim, cs = _ssim_for_multi_scale(
            im1,
            im2,
            max_val=max_val,
            filter_size=filter_size,
            filter_sigma=filter_sigma,
            k1=k1,
            k2=k2)
        mssim.append(ssim)
        mcs.append(cs)
        # Move on to the next (coarser) scale.
        im1, im2 = [_hox_downsample(x) for x in [im1, im2]]

    # Clip to zero. Otherwise we get NaNs.
    mssim = np.clip(np.asarray(mssim), 0.0, np.inf)
    mcs = np.clip(np.asarray(mcs), 0.0, np.inf)

    # Contrast terms of all but the last level times the SSIM of the last
    # level, each raised to its weight. Average over images only at the end.
    return np.mean(
        np.prod(mcs[:-1, :]**weights[:-1, np.newaxis], axis=0) *
        (mssim[-1, :]**weights[-1]))


def sliced_wasserstein(distribution_a,
                       distribution_b,
                       dir_repeats=4,
                       dirs_per_repeat=128):
    r"""Sliced Wasserstein distance between two sets of descriptors.

    Both sets are projected onto random unit directions; for every direction
    the projections are sorted and compared element-wise, and the mean
    absolute difference is averaged over ``dir_repeats`` rounds.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py  # noqa

    Args:
        distribution_a (Tensor): Descriptors of first distribution, 2-D.
        distribution_b (Tensor): Descriptors of second distribution, with
            the same shape as ``distribution_a``.
        dir_repeats (int): The number of projection times. Default to 4.
        dirs_per_repeat (int): The number of directions per projection.
            Default to 128.

    Returns:
        float: sliced Wasserstein distance.
    """
    if torch.cuda.is_available():
        distribution_b = distribution_b.cuda()
    assert distribution_a.ndim == 2
    assert distribution_a.shape == distribution_b.shape
    assert dir_repeats > 0 and dirs_per_repeat > 0

    device = distribution_b.device
    distribution_a = distribution_a.to(device)
    num_features = distribution_a.shape[1]

    def one_round():
        # Random directions, one per column, scaled to unit length.
        dirs = torch.randn(num_features, dirs_per_repeat)
        dirs = dirs / dirs.pow(2).sum(dim=0, keepdim=True).sqrt()
        dirs = dirs.to(device)
        # To save cuda memory, we perform sort in cpu
        sorted_a = torch.sort(
            torch.matmul(distribution_a, dirs).cpu(), dim=0)[0]
        sorted_b = torch.sort(
            torch.matmul(distribution_b, dirs).cpu(), dim=0)[0]
        return (sorted_a - sorted_b).abs().mean().item()

    distance = sum(one_round() for _ in range(dir_repeats)) / dir_repeats
    torch.cuda.empty_cache()
    return distance


class Metric(ABC):
    """The abstract base class of metrics. Basically, we split calculation into
    three steps. First, we initialize the metric object and do some
    preparation. Second, we will feed the real and fake images into metric
    object batch by batch, and we calculate intermediate results of these
    batches. Finally, We use these intermediate results to summarize the final
    result. And the result as a string can be obtained by property
    'result_str'.

    Args:
        num_images (int): The number of real/fake images needed to calculate
            metric.
        image_shape (tuple): Shape of the real/fake images with order "CHW".
    """

    def __init__(self, num_images, image_shape=None):
        self.num_images = num_images
        self.image_shape = image_shape
        self.num_real_need = num_images
        self.num_fake_need = num_images
        self.num_real_feeded = 0  # record of the fed real images
        self.num_fake_feeded = 0  # record of the fed fake images
        self._result_str = None  # string of metric result

    @property
    def result_str(self):
        """Get results in string format.

        ``summary`` is invoked first when no result string is cached yet.

        Returns:
            str: results in string format
        """
        if not self._result_str:
            self.summary()
        return self._result_str

    @staticmethod
    def _truncate_batch(batch, num_left):
        """Cut ``batch`` down to at most ``num_left`` samples.

        Args:
            batch (Tensor | dict): A tensor, or a dict of tensors, indexed by
                sample along the first dimension.
            num_left (int): Number of samples that are still needed.

        Returns:
            tuple: ``(batch_size, end, batch_to_feed)`` where ``batch_size``
            is the size of the incoming batch (taken from the first value
            for a dict), ``end`` the number of samples kept, and
            ``batch_to_feed`` the truncated batch of the same type.
        """
        if isinstance(batch, dict):
            batch_size = next(iter(batch.values())).shape[0]
            end = min(batch_size, num_left)
            return batch_size, end, {
                k: v[:end, ...]
                for k, v in batch.items()
            }

        batch_size = batch.shape[0]
        end = min(batch_size, num_left)
        return batch_size, end, batch[:end, ...]

    def feed(self, batch, mode):
        """Feed a image batch into metric calculator and perform intermediate
        operation in 'feed_op' function.

        Args:
            batch (Tensor | dict): Images or dict to be fed into
                metric object. If ``Tensor`` is passed, the order of ``Tensor``
                should be "NCHW". If ``dict`` is passed, each term in the
                ``dict`` are ``Tensor`` with order "NCHW".
            mode (str): Mark the batch as real or fake images. Value can be
                'reals' or 'fakes',

        Returns:
            int: The number of samples of ``batch`` that were passed on to
            ``feed_op`` (0 when enough images of this mode were already fed).

        Raises:
            ValueError: If ``mode`` is neither 'reals' nor 'fakes'.
        """
        _, ws = get_dist_info()
        if mode == 'reals':
            num_left = self.num_real_need - self.num_real_feeded
        elif mode == 'fakes':
            num_left = self.num_fake_need - self.num_fake_feeded
        else:
            raise ValueError(
                'The expected mode should be set to \'reals\' or \'fakes\','
                f'but got \'{mode}\'')

        if num_left == 0:
            return 0

        # The batch size is resolved inside the helper for both modes, so
        # dict batches work for 'fakes' as well as for 'reals'.
        batch_size, end, batch_to_feed = self._truncate_batch(batch, num_left)

        # Across all ranks ``batch_size * ws`` samples are consumed per call.
        global_end = min(batch_size * ws, num_left)
        self.feed_op(batch_to_feed, mode)
        if mode == 'reals':
            self.num_real_feeded += global_end
        else:
            self.num_fake_feeded += global_end
        return end

    def check(self):
        """Check the numbers of image."""
        assert self.num_real_feeded == self.num_fake_feeded == self.num_images

    @abstractmethod
    def prepare(self, *args, **kwargs):
        """please implement in subclass."""

    @abstractmethod
    def feed_op(self, batch, mode):
        """please implement in subclass."""

    @abstractmethod
    def summary(self):
        """please implement in subclass."""


@METRICS.register_module()
class FID(Metric):
    """FID metric.

    In this metric, we calculate the distance between real distributions and
    fake distributions. The distributions are modeled by the real samples and
    fake samples, respectively.

    `Inception_v3` is adopted as the feature extractor, which is widely used in
    StyleGAN and BigGAN.

    Args:
        num_images (int): The number of images to be tested.
        image_shape (tuple[int], optional): Image shape. Defaults to None.
        inception_pkl (str, optional): Path to reference inception pickle file.
            If `None`, the statistical value of real distribution will be
            calculated at running time. Defaults to None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. Defaults to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to `dict(normalize_input=False)`.
    """
    name = 'FID'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 inception_pkl=None,
                 bgr2rgb=True,
                 inception_args=dict(normalize_input=False)):
        super().__init__(num_images, image_shape=image_shape)
        self.inception_pkl = inception_pkl
        self.real_feats = []  # per-batch features of real images
        self.fake_feats = []  # per-batch features of fake images
        self.real_mean = None  # reference mean, from pkl or from real_feats
        self.real_cov = None  # reference covariance, same origin
        self.bgr2rgb = bgr2rgb
        self.device = 'cpu'

        # NOTE: `load_inception` deep-copies `inception_args`, so the shared
        # default dict in the signature is never mutated here.
        self.inception_net, self.inception_style = load_inception(
            inception_args, 'FID')

        if torch.cuda.is_available():
            self.inception_net = self.inception_net.cuda()
            self.device = 'cuda'
        self.inception_net.eval()

        mmcv.print_log(f'FID: Adopt Inception in {self.inception_style} style',
                       'mmgen')

    def prepare(self):
        """Prepare for evaluating models with this metric.

        When ``inception_pkl`` points to a file, the reference mean and
        covariance are read from it and the real images are marked as fully
        fed, so no real batch is consumed afterwards.
        """
        # if `inception_pkl` is provided, read mean and cov stat
        if self.inception_pkl is not None and mmcv.is_filepath(
                self.inception_pkl):
            # NOTE(security): unpickling executes arbitrary code; only point
            # `inception_pkl` at files from a trusted source.
            with open(self.inception_pkl, 'rb') as f:
                reference = pickle.load(f)
                self.real_mean = reference['mean']
                self.real_cov = reference['cov']
                mmcv.print_log(
                    f'Load reference inception pkl from {self.inception_pkl}',
                    'mmgen')
            self.num_real_feeded = self.num_images

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.

        Raises:
            ValueError: If ``mode`` is neither 'reals' nor 'fakes' (raised
                only where features are collected, see below).
        """
        if self.bgr2rgb:
            # Reverse the channel axis.
            batch = batch[:, [2, 1, 0]]
        batch = batch.to(self.device)

        if self.inception_style == 'StyleGAN':
            # The script model is fed uint8 images.
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feat = self.inception_net(batch, return_features=True)
        else:
            feat = self.inception_net(batch)[0].view(batch.shape[0], -1)

        # gather all of images if using distributed training
        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder = [torch.zeros_like(feat) for _ in range(ws)]
            dist.all_gather(placeholder, feat)
            feat = torch.cat(placeholder, dim=0)

        # in distributed training, we only collect features at rank-0.
        if not dist.is_initialized() or dist.get_rank() == 0:
            if mode == 'reals':
                self.real_feats.append(feat.cpu())
            elif mode == 'fakes':
                self.fake_feats.append(feat.cpu())
            else:
                raise ValueError(
                    "The expected mode should be set to 'reals' or 'fakes', "
                    f"but got '{mode}'")

    @staticmethod
    def _calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
        """Compute the FID from two sets of mean/covariance statistics.

        Refer to the implementation from:

        https://github.com/rosinality/stylegan2-pytorch/blob/master/fid.py#L34

        Args:
            sample_mean (ndarray): Mean of the sample features.
            sample_cov (ndarray): Covariance of the sample features.
            real_mean (ndarray): Mean of the reference features.
            real_cov (ndarray): Covariance of the reference features.
            eps (float): Value added to both covariance diagonals when the
                square root of their product is not finite. Defaults to 1e-6.

        Returns:
            tuple: ``(fid, mean_norm, trace)`` where ``fid`` is the sum of
            the squared mean difference ``mean_norm`` and the covariance
            term ``trace``.

        Raises:
            ValueError: If the matrix square root has a diagonal with a
                non-negligible imaginary part.
        """
        # With disp=False, sqrtm returns the result plus an error estimate.
        cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

        if not np.isfinite(cov_sqrt).all():
            print('product of cov matrices is singular')
            # Retry with a small offset on both diagonals.
            offset = np.eye(sample_cov.shape[0]) * eps
            cov_sqrt = linalg.sqrtm(
                (sample_cov + offset) @ (real_cov + offset))

        if np.iscomplexobj(cov_sqrt):
            if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
                m = np.max(np.abs(cov_sqrt.imag))

                raise ValueError(f'Imaginary component {m}')

            # A negligible imaginary part is dropped.
            cov_sqrt = cov_sqrt.real

        mean_diff = sample_mean - real_mean
        mean_norm = mean_diff @ mean_diff

        trace = np.trace(sample_cov) + np.trace(
            real_cov) - 2 * np.trace(cov_sqrt)

        fid = mean_norm + trace

        return fid, mean_norm, trace

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        The string form is stored in ``_result_str`` and the values in
        ``_result_dict`` (keys ``fid``, ``fid_mean``, ``fid_cov``).

        Returns:
            tuple: ``(fid, mean, cov)`` as returned by ``_calc_fid``.
        """
        # calculate reference inception stat
        if self.real_mean is None:
            feats = torch.cat(self.real_feats, dim=0)
            assert feats.shape[0] >= self.num_images
            feats = feats[:self.num_images]
            feats_np = feats.numpy()
            self.real_mean = np.mean(feats_np, 0)
            self.real_cov = np.cov(feats_np, rowvar=False)

        # calculate fake inception stat
        fake_feats = torch.cat(self.fake_feats, dim=0)
        assert fake_feats.shape[0] >= self.num_images
        fake_feats = fake_feats[:self.num_images]
        fake_feats_np = fake_feats.numpy()
        fake_mean = np.mean(fake_feats_np, 0)
        fake_cov = np.cov(fake_feats_np, rowvar=False)

        # calculate distance between real and fake statistics
        fid, mean, cov = self._calc_fid(fake_mean, fake_cov, self.real_mean,
                                        self.real_cov)

        # results for print/table
        self._result_str = (f'{fid:.4f} ({mean:.5f}/{cov:.5f})')
        # results for log_buffer
        self._result_dict = dict(fid=fid, fid_mean=mean, fid_cov=cov)

        return fid, mean, cov

    def clear_fake_data(self):
        """Clear fake data."""
        self.fake_feats = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Whether to clear real data.
                Defaults to False.
        """
        self.clear_fake_data()
        if clear_reals:
            # Only the feature buffer and the counter are reset; `real_mean`
            # and `real_cov` are left as they are.
            self.real_feats = []
            self.num_real_feeded = 0


@METRICS.register_module()
class MS_SSIM(Metric):
    """MS-SSIM (Multi-Scale Structure Similarity) metric.

    The score is accumulated over pairs of generated images: within every
    fed batch, the even-indexed samples are compared with the odd-indexed
    ones.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples. Must be
            even.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults to
            None.
    """
    name = 'MS-SSIM'

    def __init__(self, num_images, image_shape=None):
        super().__init__(num_images, image_shape)
        assert num_images % 2 == 0
        self.num_pairs = num_images // 2

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        # Running sum of pair scores, weighted by the pairs per batch.
        self.sum = 0.0

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
                Batches fed as 'reals' are ignored.
        """
        if mode == 'reals':
            return

        # Shift and scale by (x + 1) / 2, then clamp to [0, 1].
        images = ((minibatch + 1) / 2)
        images = images.clamp_(0, 1)

        def as_uint8_nhwc(tensor):
            array = tensor.cpu().data.numpy().transpose((0, 2, 3, 1))
            return (array * 255).astype('uint8')

        evens = as_uint8_nhwc(images[0::2])
        odds = as_uint8_nhwc(images[1::2])
        score = ms_ssim(evens, odds)
        pairs_in_batch = images.shape[0] // 2
        self.sum += score * pairs_in_batch

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            The accumulated score divided by the number of pairs; its value
            rounded to four decimals is stored as the result string.
        """
        self.check()
        avg = self.sum / self.num_pairs
        rounded = round(avg.item(), 4)
        self._result_str = str(rounded)
        return avg


@METRICS.register_module()
class SWD(Metric):
    """SWD (Sliced Wasserstein distance) metric.

    The SWD of two image sets is computed as follows. On every 'feed', the
    Laplacian pyramid of the batch is built and patches are extracted from
    every pyramid level as descriptors. In 'summary', the descriptors of each
    level are finalized and the sliced Wasserstein distance between the real
    and the fake descriptors is computed per level. The reported result holds
    one distance per level followed by their mean.

    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple): Image shape in order "CHW".
    """
    name = 'SWD'

    def __init__(self, num_images, image_shape):
        super().__init__(num_images, image_shape)

        # patch extraction settings
        self.nhood_size = 7  # height and width of the extracted patches
        self.nhoods_per_image = 128  # number of extracted patches per image
        # projection sampling settings
        self.dir_repeats = 4  # times of sampling directions
        self.dirs_per_repeat = 128  # number of directions per sampling

        # keep halving the image height while it is at least 16, and use at
        # most four pyramid levels
        self.resolutions = []
        side = image_shape[1]
        for _ in range(4):
            if side < 16:
                break
            self.resolutions.append(side)
            side //= 2
        self.n_pyramids = len(self.resolutions)

    def prepare(self):
        """Create one descriptor buffer per pyramid level and the kernel."""
        self.real_descs = [[] for _ in self.resolutions]
        self.fake_descs = [[] for _ in self.resolutions]
        self.gaussian_k = get_gaussian_kernel()

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor. Its shape without the batch
                axis has to equal ``image_shape``.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.

        Raises:
            ValueError: If ``mode`` is neither 'reals' nor 'fakes'.
        """
        assert minibatch.shape[1:] == self.image_shape

        # pick the buffer first so that an unknown mode fails before any
        # pyramid is computed
        if mode == 'reals':
            buffers = self.real_descs
        elif mode == 'fakes':
            buffers = self.fake_descs
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

        pyramid = laplacian_pyramid(minibatch, self.n_pyramids - 1,
                                    self.gaussian_k)
        # lod: layer_of_descriptors
        for lod, level in enumerate(pyramid):
            buffers[lod].append(
                get_descriptors_for_minibatch(level, self.nhood_size,
                                              self.nhoods_per_image))

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            list: The distance of every pyramid level (multiplied by 10^3)
                followed by the mean over all levels.
        """
        self.check()
        real_descs = list(map(finalize_descriptors, self.real_descs))
        fake_descs = list(map(finalize_descriptors, self.fake_descs))
        # the raw per-batch descriptors are no longer needed
        del self.real_descs
        del self.fake_descs

        per_level = []
        for dreal, dfake in zip(real_descs, fake_descs):
            swd = sliced_wasserstein(dreal, dfake, self.dir_repeats,
                                     self.dirs_per_repeat)
            per_level.append(swd * 1e3)  # multiply by 10^3
        del real_descs
        del fake_descs

        result = per_level + [np.mean(per_level)]
        self._result_str = ', '.join(str(round(v, 2)) for v in result)
        return result


@METRICS.register_module()
class PR(Metric):
    r"""Improved Precision and recall metric.

        In this metric, we draw real and generated samples respectively, and
        embed them into a high-dimensional feature space using a pre-trained
        classifier network. We use these features to estimate the corresponding
        manifold. We obtain the estimation by calculating pairwise Euclidean
        distances between all feature vectors in the set and, for each feature
        vector, construct a hypersphere with radius equal to the distance to its
        kth nearest neighbor. Together, these hyperspheres define a volume in
        the feature space that serves as an estimate of the true manifold.
        Precision is quantified by querying for each generated image whether
        the image is within the estimated manifold of real images.
        Symmetrically, recall is calculated by querying for each real image
        whether the image is within the estimated manifold of generated images.

        Ref: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py  # noqa

        Note that we highly recommend that users download the vgg16 script
        module from the following address. Then, ``vgg16_script`` can be set
        to the user's local path. If the file does not exist, we will use the
        vgg16 from the pytorch model zoo. However, this may lead to
        significantly different final results.

        Tero's vgg16: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt  # noqa

        Args:
            num_images (int): The number of evaluated generated samples.
            image_shape (tuple): Image shape in order "CHW". Defaults to None.
            num_real_need (int | None, optional): The number of real images.
                Defaults to None.
            full_dataset (bool, optional): Whether to use full dataset for
                evaluation. Defaults to False.
            k (int, optional): Kth nearest parameter. Defaults to 3.
            bgr2rgb (bool, optional): Whether to change the order of image
                channel. Defaults to True.
            vgg16_script (str, optional): Path for the Tero's vgg16 module.
                Defaults to 'work_dirs/cache/vgg16.pt'.
            row_batch_size (int, optional): The batch size of row data.
                Defaults to 10000.
            col_batch_size (int, optional): The batch size of col data.
                Defaults to 10000.
        """
    name = 'PR'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 num_real_need=None,
                 full_dataset=False,
                 k=3,
                 bgr2rgb=True,
                 vgg16_script='work_dirs/cache/vgg16.pt',
                 row_batch_size=10000,
                 col_batch_size=10000):
        super().__init__(num_images, image_shape)
        mmcv.print_log('loading vgg16 for improved precision and recall...',
                       'mmgen')
        if os.path.isfile(vgg16_script):
            # load the module from the path given by the caller; a hard-coded
            # path here would silently ignore a custom ``vgg16_script``
            self.vgg16 = torch.jit.load(vgg16_script).eval()
            self.use_tero_scirpt = True
        else:
            mmcv.print_log(
                'Cannot load Tero\'s script module. Use official '
                'vgg16 instead', 'mmgen')
            self.vgg16 = models.vgg16(pretrained=True).eval()
            self.use_tero_scirpt = False
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.vgg16 = self.vgg16.cuda()
            self.device = 'cuda'
        self.k = k

        self.bgr2rgb = bgr2rgb
        self.full_dataset = full_dataset
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        # only override the value set by the base class when a (non-zero)
        # number is given
        if num_real_need:
            self.num_real_need = num_real_need

        if self.full_dataset:
            # a very large quota so that feeding is bounded by the dataset
            # size -- presumably never reached in practice
            self.num_real_need = 10000000

    def prepare(self, *args, **kwargs):
        """Prepare for evaluating models with this metric."""
        self.features_of_reals = []
        self.features_of_fakes = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.

        Raises:
            ValueError: If ``mode`` is neither 'reals' nor 'fakes'.
        """
        batch = batch.to(self.device)
        if self.bgr2rgb:
            # reverse the channel axis
            batch = batch[:, [2, 1, 0], ...]
        if self.use_tero_scirpt:
            # the script module is fed uint8 images in [0, 255]
            batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        if mode == 'reals':
            self.features_of_reals.append(self.extract_features(batch))
        elif mode == 'fakes':
            self.features_of_fakes.append(self.extract_features(batch))
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    def check(self):
        """Check that the required number of samples has been fed.

        When evaluating on the full dataset, only the number of fake samples
        is checked and the number of fed real samples is logged.
        """
        if not self.full_dataset:
            assert (self.num_real_feeded == self.num_real_need
                    and self.num_fake_feeded == self.num_fake_need)
        else:
            assert self.num_fake_feeded == self.num_fake_need
            mmcv.print_log(
                f'Test for the full dataset with {self.num_real_feeded}'
                ' real images', 'mmgen')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            dict: Contains ``precision`` and ``recall``. Both values are
                ``nan`` on every rank other than 0.
        """
        self.check()

        real_features = torch.cat(self.features_of_reals)
        gen_features = torch.cat(self.features_of_fakes)

        self._result_dict = {}
        rank, ws = get_dist_info()

        # precision probes the real manifold with generated samples, recall
        # probes the generated manifold with real samples
        for name, manifold, probes in [
            ('precision', real_features, gen_features),
            ('recall', gen_features, real_features)
        ]:
            # radius of every manifold sample; results are only consumed on
            # rank 0
            kth = []
            for manifold_batch in manifold.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=manifold_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                # ``k + 1`` because every row is compared with the whole
                # manifold, which presumably includes the sample itself
                kth.append(
                    distance.to(torch.float32).kthvalue(self.k + 1).values.
                    to(torch.float16) if rank == 0 else None)
            kth = torch.cat(kth) if rank == 0 else None

            # a probe is covered when it lies within the radius of at least
            # one manifold sample
            pred = []
            for probes_batch in probes.split(self.row_batch_size):
                distance = compute_pr_distances(
                    row_features=probes_batch,
                    col_features=manifold,
                    num_gpus=ws,
                    rank=rank,
                    col_batch_size=self.col_batch_size)
                pred.append((distance <= kth).any(
                    dim=1) if rank == 0 else None)
            self._result_dict[name] = float(
                torch.cat(pred).to(torch.float32).mean() if rank ==
                0 else 'nan')

        precision = self._result_dict['precision']
        recall = self._result_dict['recall']
        self._result_str = f'precision: {precision}, recall:{recall}'
        return self._result_dict

    def extract_features(self, images):
        """Extracting image features.

        Args:
            images (torch.Tensor): Images tensor.
        Returns:
            torch.Tensor: Vgg16 features of input images.
        """
        if self.use_tero_scirpt:
            feature = self.vgg16(images, return_features=True)
        else:
            # torchvision vgg16: resize to 224x224 and take the output of the
            # first four classifier layers as feature
            batch = F.interpolate(images, size=(224, 224))
            before_fc = self.vgg16.features(batch)
            before_fc = before_fc.view(-1, 7 * 7 * 512)
            feature = self.vgg16.classifier[:4](before_fc)

        return feature


@METRICS.register_module()
class IS(Metric):
    """IS (Inception Score) metric.

    The images are split into groups, and the inception score is calculated
    on each group of images, then the mean and standard deviation of the score
    is reported. The calculation of the inception score on a group of images
    involves first using the inception v3 model to calculate the conditional
    probability for each image (p(y|x)). The marginal probability is then
    calculated as the average of the conditional probabilities for the images
    in the group (p(y)). The KL divergence is then calculated for each image as
    the conditional probability multiplied by the log of the conditional
    probability minus the log of the marginal probability. The KL divergence is
    then summed over all images and averaged over all classes and the exponent
    of the result is calculated to give the final score.

    Ref: https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py  # noqa

    Note that we highly recommend that users download the Inception V3
    script module from the following address and pass its location through
    ``inception_args``. Using the Inception V3 from the pytorch model zoo
    instead may lead to significantly different final results.

    Tero's Inception V3: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt  # noqa

    Args:
        num_images (int): The number of evaluated generated samples.
        image_shape (tuple, optional): Image shape in order "CHW". Defaults to
            None.
        bgr2rgb (bool, optional): If True, reformat the BGR image to RGB
            format. In default, our model generate images in the BGR order.
            Thus, we use `True` as the default behavior. Please switch to
            `False`, if the input is in the `RGB` order. Defaults to True.
        resize (bool, optional): Whether resize image to 299x299. Defaults to
            True.
        splits (int, optional): The number of groups. Defaults to 10.
        use_pil_resize (bool, optional): Whether use Bicubic interpolation with
            Pillow's backend. If set as True, the evaluation process may be a
            little bit slow, but achieve a more accurate IS result. Defaults
            to True.
        inception_args (dict, optional): Keyword args for inception net.
            Defaults to
            ``dict(type='StyleGAN', inception_path=TERO_INCEPTION_URL)``.
    """
    name = 'IS'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 bgr2rgb=True,
                 resize=True,
                 splits=10,
                 use_pil_resize=True,
                 inception_args=dict(
                     type='StyleGAN', inception_path=TERO_INCEPTION_URL)):
        super().__init__(num_images, image_shape)
        mmcv.print_log('Loading Inception V3 for IS...', 'mmgen')

        model, style = load_inception(inception_args, 'IS')

        self.inception_model = model
        self.use_tero_script = style == 'StyleGAN'

        # 'reals' batches are ignored by ``feed_op``; the real counter is
        # preset, presumably so the base-class bookkeeping sees it as complete
        self.num_real_feeded = self.num_images
        self.resize = resize
        self.splits = splits
        self.bgr2rgb = bgr2rgb
        self.use_pil_resize = use_pil_resize
        self._pil_resize_warned = False

        self.device = 'cpu'
        if torch.cuda.is_available():
            self.inception_model = self.inception_model.cuda()
            self.device = 'cuda'
        self.inception_model.eval()

    def pil_resize(self, x):
        """Apply Bicubic interpolation with Pillow backend. Before and after
        interpolation operation, we have to perform a type conversion between
        torch.tensor and PIL.Image, and these operations make resize process a
        bit slow.

        Args:
            x (Tensor): Input tensor, should have four dimension and
                        range in [-1, 1].

        Returns:
            torch.FloatTensor: Resized tensor.

        Raises:
            ValueError: If ``x`` does not have four dimensions.
        """
        if not self._pil_resize_warned:
            # only log the hint once per metric instance
            mmcv.print_log(
                '`use_pil_resize` is set as True, apply Bicubic '
                'interpolation with Pillow backend. We perform '
                'type conversion between torch.tensor and '
                'PIL.Image in this function and make this process '
                'a little bit slow.', 'mmgen')
            self._pil_resize_warned = True

        from PIL import Image
        if x.ndim != 4:
            raise ValueError('Input images should have 4 dimensions, '
                             'here receive input with {} '
                             'dimensions.'.format(x.ndim))

        # [-1, 1] float -> [0, 255] uint8, channel axis last for Pillow
        x = (x.clone() * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        x_np = [x_.permute(1, 2, 0).detach().cpu().numpy() for x_ in x]
        # request bicubic explicitly: the default filter of ``Image.resize``
        # depends on the Pillow version
        x_pil = [
            Image.fromarray(x_).resize((299, 299), resample=Image.BICUBIC)
            for x_ in x_np
        ]
        x_ten = torch.cat(
            [torch.FloatTensor(np.array(x_)[None, ...]) for x_ in x_pil])
        # back to [-1, 1] float with the channel axis second
        x_ten = (x_ten / 127.5 - 1).to(torch.float)
        return x_ten.permute(0, 3, 1, 2)

    def get_pred(self, x):
        """Get prediction from inception model.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Prediction of the inception model. For the non-script
                model, softmax over dimension 1 is applied to the output.
        """
        if self.use_tero_script:
            x = self.inception_model(x, no_output_bias=True)
        else:
            # specify the dimension to avoid warning
            x = F.softmax(self.inception_model(x), dim=1)
        return x

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.preds = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Feed data to the metric.

        Args:
            batch (Tensor): Input tensor.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
                Batches fed with 'reals' are ignored.

        Raises:
            ValueError: If ``mode`` is neither 'reals' nor 'fakes'.
        """
        if mode == 'reals':
            pass
        elif mode == 'fakes':
            if self.bgr2rgb:
                batch = batch[:, [2, 1, 0], ...]
            if self.resize:
                if self.use_pil_resize:
                    batch = self.pil_resize(batch)
                else:
                    batch = F.interpolate(
                        batch, size=(299, 299), mode='bilinear')
            if self.use_tero_script:
                # the script module is fed uint8 images in [0, 255]
                batch = (batch * 127.5 + 128).clamp(0, 255).to(torch.uint8)

            batch = batch.to(self.device)

            # get prediction
            pred = self.get_pred(batch)

            if dist.is_initialized():
                # collect the predictions of all ranks
                ws = dist.get_world_size()
                placeholder = [torch.zeros_like(pred) for _ in range(ws)]
                dist.all_gather(placeholder, pred)
                pred = torch.cat(placeholder, dim=0)

            # in distributed training, we only collect features at rank-0.
            if (dist.is_initialized()
                    and dist.get_rank() == 0) or not dist.is_initialized():
                self.preds.append(pred.cpu().numpy())
        else:
            raise ValueError(f'{mode} is not a implemented feed mode.')

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        TODO: support `master_only`

        Returns:
            tuple: Mean and standard deviation of the scores of all splits.
        """
        split_scores = []
        self.preds = np.concatenate(self.preds, axis=0)
        # check for the size
        assert self.preds.shape[0] >= self.num_images
        self.preds = self.preds[:self.num_images]
        # samples beyond ``splits * split_size`` are not evaluated
        split_size = self.num_images // self.splits
        for k in range(self.splits):
            part = self.preds[k * split_size:(k + 1) * split_size, :]
            # marginal distribution p(y) of the current split
            py = np.mean(part, axis=0)
            # KL(p(y|x) || p(y)) for every sample of the split
            scores = [entropy(pyx, py) for pyx in part]
            split_scores.append(np.exp(np.mean(scores)))

        mean, std = np.mean(split_scores), np.std(split_scores)
        # results for print/table
        self._result_str = f'mean: {mean:.3f}, std: {std:.3f}'
        # results for log_buffer
        self._result_dict = {'is': mean, 'is_std': std}
        return mean, std

    def clear_fake_data(self):
        """Clear fake data."""
        self.preds = []
        self.num_fake_feeded = 0

    def clear(self, clear_reals=False):
        """Clear data buffers.

        Args:
            clear_reals (bool, optional): Accepted for interface
                compatibility; it is not used because this metric keeps no
                real data. Defaults to False.
        """
        self.clear_fake_data()


@METRICS.register_module()
class PPL(Metric):
    r"""Perceptual path length.

        Measure the difference between consecutive images (their VGG16
        embeddings) when interpolating between two random inputs. Drastic
        changes mean that multiple features have changed together and that
        they might be entangled.

        Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py # noqa

        Args:
            num_images (int): The number of evaluated generated samples.
            image_shape (tuple, optional): Image shape in order "CHW". Defaults
                to None.
            crop (bool, optional): Whether crop images. Defaults to True.
            epsilon (float, optional): Epsilon parameter for path sampling.
                Defaults to 1e-4.
            space (str, optional): Latent space. Defaults to 'W'.
            sampling (str, optional): Sampling mode, whether sampling in full
                path or endpoints. Defaults to 'end'.
            latent_dim (int, optional): Latent dimension of input noise.
                Defaults to 512.
    """
    name = 'PPL'

    def __init__(self,
                 num_images,
                 image_shape=None,
                 crop=True,
                 epsilon=1e-4,
                 space='W',
                 sampling='end',
                 latent_dim=512):
        super().__init__(num_images, image_shape=image_shape)
        self.crop = crop
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.latent_dim = latent_dim
        # every evaluated sample consists of a pair of images
        self.num_images = num_images * 2
        # 'reals' batches are ignored by ``feed_op``; the real counter is
        # preset, presumably so the base-class bookkeeping sees it as complete
        self.num_real_feeded = self.num_images

    def prepare(self):
        """Prepare for evaluating models with this metric."""
        self.dist_list = []

    @staticmethod
    def _percentile(values, q, method):
        """Compute the ``q``-th percentile with the given estimation method.

        NumPy >= 1.22 takes the estimation method through the ``method``
        keyword and deprecates the former ``interpolation`` keyword, while
        older releases only know ``interpolation``.
        """
        try:
            return np.percentile(values, q, method=method)
        except TypeError:
            # NumPy < 1.22 rejects the unknown ``method`` keyword
            return np.percentile(values, q, interpolation=method)

    @torch.no_grad()
    def feed_op(self, minibatch, mode):
        """Feed data to the metric.

        Args:
            minibatch (Tensor): Input tensor. Even-indexed and odd-indexed
                samples are compared with each other as pairs.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
                Batches fed with 'reals' are ignored.
        """
        if mode == 'reals':
            return
        # use minibatch's device type to initialize a lpips calculator
        if not hasattr(self, 'percept'):
            self.percept = PerceptualLoss(
                use_gpu=minibatch.device.type.startswith('cuda'))
        # crop and resize images
        if self.crop:
            # with c = H // 8, keep rows [3c, 7c) and columns [2c, 6c)
            c = minibatch.shape[2] // 8
            minibatch = minibatch[:, :, c * 3:c * 7, c * 2:c * 6]

        # only downsample when the height is at least twice 256
        factor = minibatch.shape[2] // 256
        if factor > 1:
            minibatch = F.interpolate(
                minibatch,
                size=(256, 256),
                mode='bilinear',
                align_corners=False)
        # calculator and store lpips score
        distance = self.percept(minibatch[::2], minibatch[1::2]).view(
            minibatch.shape[0] // 2) / (
                self.epsilon**2)
        self.dist_list.append(distance.to('cpu').numpy())

    @torch.no_grad()
    def summary(self):
        """Summarize the results.

        Returns:
            The mean of all stored distances that lie between the 1st and
            the 99th percentile (both bounds included).
        """
        distances = np.concatenate(self.dist_list, 0)
        # reject outliers outside of the [1%, 99%] percentile range
        lo = self._percentile(distances, 1, 'lower')
        hi = self._percentile(distances, 99, 'higher')
        filtered_dist = np.extract(
            np.logical_and(lo <= distances, distances <= hi), distances)
        ppl_score = filtered_dist.mean()
        self._result_str = f'{ppl_score:.1f}'
        return ppl_score

    def get_sampler(self, model, batch_size, sample_model):
        """Get sampler for sampling along the path.

        Args:
            model (nn.Module): Generative model.
            batch_size (int): Sampling batch size.
            sample_model (str): Which model you want to use. 'ema' selects
                ``model.generator_ema``; any other value selects
                ``model.generator``.

        Returns:
            PPLSampler: A sampler yielding the image batches for calculating
                the perceptual path length.
        """
        if sample_model == 'ema':
            generator = model.generator_ema
        else:
            generator = model.generator
        ppl_sampler = PPLSampler(generator, self.num_images, batch_size,
                                 self.space, self.sampling, self.epsilon,
                                 self.latent_dim)
        return ppl_sampler


class PPLSampler:
    """StyleGAN series generator's sampling iterator for PPL metric.

    Iterating over the sampler yields one generated batch per entry of
    ``batch_sizes``. In every batch, two neighbouring images belong to the
    interpolation positions ``t`` and ``t + epsilon`` of the same pair of
    latent endpoints.

    Args:
        generator (nn.Module): StyleGAN series' generator.
        num_images (int): The number of evaluated generated samples.
        batch_size (int): Batch size of generated images.
        space (str, optional): Latent space. Defaults to 'W'.
        sampling (str, optional): Sampling mode, whether sampling in full
            path or endpoints. Defaults to 'end'.
        epsilon (float, optional): Epsilon parameter for path sampling.
            Defaults to 1e-4.
        latent_dim (int, optional): Latent dimension of input noise.
            Defaults to 512.
    """

    def __init__(self,
                 generator,
                 num_images,
                 batch_size,
                 space='W',
                 sampling='end',
                 epsilon=1e-4,
                 latent_dim=512):
        assert space in ('Z', 'W')
        assert sampling in ('full', 'end')

        # full batches followed by one smaller batch for the remainder
        full_batches, remainder = divmod(num_images, batch_size)
        sizes = [batch_size] * full_batches
        if remainder > 0:
            sizes = sizes + [remainder]
        self.batch_sizes = sizes

        self.device = get_module_device(generator)
        self.generator = generator
        self.latent_dim = latent_dim
        self.space = space
        self.sampling = sampling
        self.epsilon = epsilon

    def __iter__(self):
        # restart from the first batch
        self.idx = 0
        return self

    @torch.no_grad()
    def __next__(self):
        if not self.idx < len(self.batch_sizes):
            raise StopIteration

        n_pairs = self.batch_sizes[self.idx]
        injected_noise = self.generator.make_injected_noise()
        # two endpoints per pair
        inputs = torch.randn([n_pairs * 2, self.latent_dim],
                             device=self.device)
        # 'full' draws a random position on the path, 'end' uses position 0
        sample_t = torch.rand if self.sampling == 'full' else torch.zeros
        position = sample_t(n_pairs, device=self.device)[:, None]

        # choose the latent endpoints and the matching interpolation
        if self.space == 'W':
            assert hasattr(self.generator, 'style_mapping')
            endpoints = self.generator.style_mapping(inputs)
            interpolate = torch.lerp
            is_latent = True
        else:
            endpoints = inputs
            interpolate = slerp
            is_latent = False

        start, end = endpoints[::2], endpoints[1::2]
        near = interpolate(start, end, position)
        far = interpolate(start, end, position + self.epsilon)
        # interleave so that both positions of a pair are neighbours
        latent_e = torch.stack([near, far], 1).view(*endpoints.shape)
        image = self.generator([latent_e],
                               input_is_latent=is_latent,
                               injected_noise=injected_noise)

        self.idx += 1
        return image


@METRICS.register_module()
class GaussianKLD(Metric):
    r"""Kullback-Leibler divergence (KLD) between two Gaussians.

    The divergence is computed from the ``mean`` and ``log_variance`` of the
    two distributions. Every batch passed to this metric must be a dict that
    contains ``mean_pred``, ``mean_target``, ``logvar_pred`` and
    ``logvar_target``. Only the ``reals`` mode of the ``feed`` operation is
    used; batches fed in ``fakes`` mode are ignored.

    The KLD is given by:

    .. math::
        :nowrap:

        \begin{align}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{align}

    where `p` (subscript 1) is the target distribution and `q` (subscript 2)
    is the predicted distribution.

    Args:
        num_images (int): The number of samples to be tested.
        base (str, optional): The log base of the calculated KLD. Supports
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        reduction (str, optional): How the collected element-wise KLD values
            are reduced in ``summary``. Supports ``'batchmean'``, ``'sum'``
            and ``'mean'``. ``'sum'`` adds up all values, ``'mean'`` divides
            that sum by the number of elements, and ``'batchmean'`` divides
            it by the size of the first dimension. Defaults to
            ``'batchmean'``.

    """
    name = 'GaussianKLD'

    def __init__(self, num_images, base='e', reduction='batchmean'):
        super().__init__(num_images, image_shape=None)
        supported_reductions = ('sum', 'batchmean', 'mean')
        supported_bases = ('e', '2')
        assert reduction in supported_reductions, (
            "We only support reduction for 'batchmean', 'sum' and 'mean'")
        assert base in supported_bases, (
            "We only support log_base for 'e' and '2'")
        self.reduction = reduction
        # NOTE(review): presumably this marks the fake side as already
        # complete, because ``feed_op`` discards 'fakes' batches; confirm
        # against the ``Metric`` base class.
        self.num_fake_feeded = self.num_images
        # Element-wise values are stored; the reduction itself is applied
        # in ``summary``.
        self.cal_kld = partial(
            gaussian_kld, base=base, weight=1, reduction='none')

    def prepare(self):
        """Reset the collected KLD values and the real-sample counter."""
        self.num_real_feeded = 0
        self.kld = []

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Compute the element-wise KLD of one batch and store it.

        Args:
            batch (dict): Must contain ``mean_pred``, ``mean_target``,
                ``logvar_pred`` and ``logvar_target``.
            mode (str): The mode of the current batch. ``'fakes'`` batches
                are ignored.

        Raises:
            AssertionError: If ``batch`` is not a dict.
            KeyError: If any of the required keys is missing.
        """
        if mode == 'fakes':
            return
        assert isinstance(batch, dict), ('To calculate GaussianKLD loss, a '
                                         'dict contains probabilistic '
                                         'parameters is required.')
        # Kept as a list because its repr is part of the error message.
        require_keys = [
            'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
        ]
        if not all(key in batch for key in require_keys):
            raise KeyError(f'The input dict must require {require_keys} at '
                           'the same time. But keys in the given dict are '
                           f'{batch.keys()}. Some of the requirements are '
                           'missing.')

        kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
                           batch['logvar_target'], batch['logvar_pred'])

        distributed = dist.is_initialized()
        if distributed:
            # Collect the values of every process and join them along the
            # first dimension.
            gathered = [
                torch.zeros_like(kld) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered, kld)
            kld = torch.cat(gathered, dim=0)

        # In the distributed case only rank 0 keeps the gathered values.
        if not distributed or dist.get_rank() == 0:
            self.kld.append(kld.cpu())

    @torch.no_grad()
    def summary(self):
        """Reduce all collected KLD values to a single number.

        The formatted result (four decimals) is also stored in
        ``self._result_str``.

        Returns:
            numpy scalar: The KLD reduced according to ``self.reduction``.
        """
        collected = torch.cat(self.kld, dim=0)
        assert collected.shape[0] >= self.num_images
        collected_np = collected.numpy()

        if self.reduction == 'mean':
            kld_result = np.mean(collected_np)
        else:
            kld_result = np.sum(collected_np)
            if self.reduction != 'sum':
                # 'batchmean': normalise the sum by the first dimension.
                kld_result = kld_result / collected_np.shape[0]

        self._result_str = f'{kld_result:.4f}'
        return kld_result