visualizer.py 1.21 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn
import torch.nn.functional as F
import torchvision.utils as tvu


class Visualizer(object):
    """Logging helper that wraps a TensorBoard ``SummaryWriter``.

    Scalars and image grids are written as TensorBoard events into the
    directory given to the constructor. Comparison grids produced by
    :meth:`save_image` are written as PNG files into ``self.imgs_path``.
    """

    def __init__(self, path):
        """Create the output directories and open the event writer.

        Args:
            path: Directory for the TensorBoard event files. Every
                occurrence of the substring ``'events'`` in it is replaced
                by ``'Imgs'`` to obtain the directory for saved PNG grids;
                if the substring is absent, both outputs share ``path``.
        """
        # exist_ok=True avoids the check-then-create race that occurs when
        # several processes start logging into the same location at once.
        os.makedirs(path, exist_ok=True)
        self.writer = SummaryWriter(path)
        self.imgs_path = path.replace('events', 'Imgs')
        os.makedirs(self.imgs_path, exist_ok=True)

    def Scalar(self, name, val, step):
        """Record the scalar ``val`` under tag ``name`` at ``step``."""
        self.writer.add_scalar(name, val, step)

    def Image(self, name, val, step):
        """Arrange ``val`` into a grid and record it under ``name`` at ``step``.

        ``val`` is passed straight to ``torchvision.utils.make_grid``.
        """
        grid = tvu.make_grid(val)
        self.writer.add_image(name, grid, step)

    def save_image(self, val, val_deform, step, srcs, drv, num,
                   src_ldmk_line=None, drv_ldmk_line=None, test_flag=-1):
        """Save a PNG grid comparing driving, source and generated images.

        The first ``num`` entries of each group are concatenated along
        dimension 0 in this order: ``drv``, ``drv_ldmk_line``, ``srcs``,
        ``src_ldmk_line``, ``val_deform``, ``val``. The grid is saved with
        ``nrow=num`` and ``normalize=True``.

        Args:
            val: Final output images.
            val_deform: Intermediate (deformed) output images.
            step: Step identifier used in the file name (via ``str``).
            srcs: Source images.
            drv: Driving images.
            num: Number of leading images taken from every group; also the
                number of images per grid row.
            src_ldmk_line: Optional source landmark drawings. Omitted from
                the grid when ``None``.
            drv_ldmk_line: Optional driving landmark drawings. Omitted from
                the grid when ``None``.
            test_flag: ``-1`` saves ``<step>.png``; any other value saves
                ``test_<step>_<test_flag>.png``.
        """
        # The landmark arguments default to None, so they must be skipped
        # rather than sliced; slicing None raises TypeError.
        groups = [drv[:num]]
        if drv_ldmk_line is not None:
            groups.append(drv_ldmk_line[:num])
        groups.append(srcs[:num])
        if src_ldmk_line is not None:
            groups.append(src_ldmk_line[:num])
        groups.append(val_deform[:num])
        groups.append(val[:num])
        imgs = torch.cat(groups, dim=0)

        # str() is used deliberately (not format()) so the file name is
        # exactly what the string conversion of ``step`` yields.
        if test_flag == -1:
            filename = str(step) + '.png'
        else:
            filename = 'test_' + str(step) + '_' + str(test_flag) + '.png'

        tvu.save_image(imgs, os.path.join(self.imgs_path, filename),
                       nrow=num, normalize=True)

    def close(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()