visualization.py 3.18 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings

import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image


@HOOKS.register_module('MMGenVisualizationHook')
class VisualizationHook(Hook):
    """Visualization hook.

    In this hook, we use the official api `save_image` in torchvision to save
    the visualization results.

    The tensors selected by ``res_name_list`` are concatenated side by side
    along the width dimension (dim 3), so they must agree in batch size,
    channels and height.

    Args:
        output_dir (str): The directory to store visualizations. It is joined
            onto ``runner.work_dir`` with ``os.path.join`` (so an absolute
            path is used as-is).
        res_name_list (list[str]): The names of the results in the outputs
            dict to visualize. The results in outputs dict must be a
            torch.Tensor with shape (n, c, h, w). Names that are missing from
            the outputs dict are skipped.
        interval (int): The interval of calling this hook. If set to -1,
            the visualization hook will not be called. Default: -1.
        filename_tmpl (str): Format string used to save images. It is
            formatted with the 1-based iteration number
            (``runner.iter + 1``). Default: 'iter_{}.png'.
        rerange (bool): Whether to rerange the output value from [-1, 1] to
            [0, 1]. We highly recommend users should preprocess the
            visualization results on their own. Here, we just provide a simple
            interface. Default: True.
        bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
            RGB. The final image we will save is following RGB style. The
            conversion is only applied when the results have at least 3
            channels; results with fewer channels (e.g. grayscale) are saved
            without channel reordering. Default: True.
        nrow (int): The number of samples in a row. Default: 1.
        padding (int): The number of padding pixels between each samples.
            Default: 4.
    """

    def __init__(self,
                 output_dir,
                 res_name_list,
                 interval=-1,
                 filename_tmpl='iter_{}.png',
                 rerange=True,
                 bgr2rgb=True,
                 nrow=1,
                 padding=4):
        assert mmcv.is_list_of(res_name_list, str), (
            f'res_name_list must be a list of str, got {res_name_list!r}')
        self.output_dir = output_dir
        self.res_name_list = res_name_list
        self.interval = interval
        self.filename_tmpl = filename_tmpl
        self.bgr2rgb = bgr2rgb
        self.rerange = rerange
        self.nrow = nrow
        self.padding = padding
        # Resolved lazily on the first save, because ``runner.work_dir`` is
        # only available inside the hook callbacks.
        self._out_dir = None

    @master_only
    def after_train_iter(self, runner):
        """Save the selected results as one image grid every ``interval``
        iterations.

        Args:
            runner (object): The runner. ``runner.outputs['results']`` is
                read as a mapping from result names to tensors.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        results = runner.outputs['results']

        img_list = [results[k] for k in self.res_name_list if k in results]
        if not img_list:
            # ``torch.cat`` raises on an empty list; a missing visualization
            # should not abort training, so warn and skip this iteration.
            warnings.warn(
                'VisualizationHook: none of the names in res_name_list '
                f'{self.res_name_list} were found in the outputs; skipping '
                f'visualization at iteration {runner.iter + 1}.')
            return

        # Place the selected results side by side along the width dimension.
        img_cat = torch.cat(img_list, dim=3).detach()
        if self.rerange:
            img_cat = (img_cat + 1) / 2
        # Reordering needs at least 3 channels; with fewer (e.g. grayscale)
        # the BGR -> RGB swap is meaningless and the index would fail.
        if self.bgr2rgb and img_cat.size(1) >= 3:
            img_cat = img_cat[:, [2, 1, 0], ...]
        img_cat = img_cat.clamp_(0, 1)

        if self._out_dir is None:
            self._out_dir = osp.join(runner.work_dir, self.output_dir)
        mmcv.mkdir_or_exist(self._out_dir)

        filename = self.filename_tmpl.format(runner.iter + 1)
        save_image(
            img_cat,
            osp.join(self._out_dir, filename),
            nrow=self.nrow,
            padding=self.padding)