# viz.py

import visdom
import matplotlib.pyplot as PL
from util import *
import numpy as np
import cv2

def _fig_to_ndarray(fig):
    """Render a matplotlib figure into a channel-first RGB ``uint8`` array.

    The figure's canvas is drawn, its pixels are copied into a NumPy array,
    and the figure is then closed with ``PL.close(fig)``.  Because of that
    last step the figure should not be shown or saved through pyplot after
    this call.

    Args:
        fig: Figure-like object whose ``canvas`` offers ``draw()`` and either
            ``tostring_rgb()`` together with ``get_width_height()``, or
            ``buffer_rgba()``.

    Returns:
        A writable ``numpy.ndarray`` of dtype ``uint8`` and shape
        ``(3, height, width)``; the channels are in RGB order (no BGR
        conversion is applied).
    """
    canvas = fig.canvas
    canvas.draw()

    if hasattr(canvas, 'tostring_rgb'):
        width, height = canvas.get_width_height()
        # np.fromstring is deprecated for binary input; np.frombuffer is the
        # replacement.  It returns a read-only view of the bytes, so copy to
        # keep handing callers a writable array as before.
        data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape((height, width, 3)).copy()
    else:
        # Canvases without tostring_rgb (it was deprecated and later removed
        # by matplotlib) expose the pixels as RGBA; drop the alpha channel
        # and copy so the result does not alias the canvas buffer.
        rgba = np.asarray(canvas.buffer_rgba())
        data = np.array(rgba[..., :3], dtype=np.uint8)

    # (height, width, channel) -> (channel, height, width)
    data = data.transpose(2, 0, 1)
    PL.close(fig)

    return data

class VisdomWindowManager(visdom.Visdom):
    """``visdom.Visdom`` client with helpers for scalars and mpl figures.

    On top of the inherited client methods this class keeps a little local
    state so that callers can

    * grow a line plot one scalar at a time (``append_scalar``),
    * push a single matplotlib figure as an image (``display_mpl_figure``),
    * collect matplotlib figures as frames and push them as a video
      (``reset_mpl_figure_sequence``, ``append_mpl_figure_to_sequence``,
      ``display_mpl_figure_sequence``).
    """

    def __init__(self, **kwargs):
        """Create the client; ``kwargs`` are passed to ``visdom.Visdom``."""
        visdom.Visdom.__init__(self, **kwargs)
        # name -> number of values received so far for that scalar plot
        self.scalar_plot_length = {}
        # name -> (value, t) of the most recently received point
        self.scalar_plot_prev_point = {}

        # name -> list of frames, each a (height, width, 3) uint8 array
        self.mpl_figure_sequence = {}

    def append_scalar(self, name, value, t=None, opts=None):
        """Append one scalar to the line plot shown in window ``name``.

        The first value received for ``name`` closes any window of that name
        and is only remembered; nothing is drawn yet.  Every later value
        draws a segment from the previously remembered point to the new one.

        Args:
            name: Window name; also used as the default y-axis label.
            value: The scalar to plot.
            t: X coordinate of the point.  When ``None`` it defaults to the
                number of values received so far for ``name`` (0 for the
                first value).
            opts: Optional dict of plot options; its entries override the
                defaults ``{'xlabel': 'time', 'ylabel': name}``.
        """
        if self.scalar_plot_length.get(name, 0) == 0:
            # If we are creating a scalar plot, store the starting point but
            # don't plot anything yet
            self.close(name)
            t = 0 if t is None else t
            self.scalar_plot_length[name] = 0
        else:
            # If we have at least two values, then plot a segment
            t = self.scalar_plot_length[name] if t is None else t
            prev_v, prev_t = self.scalar_plot_prev_point[name]
            newopts = {'xlabel': 'time', 'ylabel': name}
            if opts is not None:
                newopts.update(opts)
            self.line(
                X=np.array([prev_t, t]),
                Y=np.array([prev_v, value]),
                win=name,
                # Append to the window if it is already there, otherwise
                # let this call create it.
                update=None if not self.win_exists(name) else 'append',
                opts=newopts,
            )

        self.scalar_plot_prev_point[name] = (value, t)
        self.scalar_plot_length[name] += 1

    def display_mpl_figure(self, fig, **kwargs):
        """Send a matplotlib figure to visdom as an image.

        The figure is rendered by ``_fig_to_ndarray``, which also closes it,
        so this call replaces ``PL.show()`` / ``PL.savefig()`` for the figure
        rather than preceding them.

        Args:
            fig: The matplotlib figure to display.
            **kwargs: Passed through unchanged to the inherited ``image``.
        """
        self.image(_fig_to_ndarray(fig), **kwargs)

    def reset_mpl_figure_sequence(self, name):
        """Start (or restart) the frame sequence ``name`` as an empty list."""
        self.mpl_figure_sequence[name] = []

    def append_mpl_figure_to_sequence(self, name, fig):
        """Render ``fig`` and append it as a frame to sequence ``name``.

        The sequence is created on first use.  ``fig`` is closed by the
        rendering helper.

        Args:
            name: Key of the frame sequence.
            fig: The matplotlib figure to render.
        """
        data = _fig_to_ndarray(fig)
        # Frames are stored channel-last: (3, H, W) -> (H, W, 3).
        data = data.transpose(1, 2, 0)
        if name not in self.mpl_figure_sequence:
            self.reset_mpl_figure_sequence(name)
        self.mpl_figure_sequence[name].append(data)

    def display_mpl_figure_sequence(self, name, **kwargs):
        """Send the frames collected under ``name`` to visdom as a video.

        Every frame is resized to the size of the first frame so that the
        frames can be stacked into a single array.  The stored sequence is
        left untouched.

        Args:
            name: Key of the frame sequence.
            **kwargs: Passed through unchanged to the inherited ``video``.

        Raises:
            KeyError: If no sequence called ``name`` exists.
            IndexError: If the sequence exists but holds no frames.
        """
        if name not in self.mpl_figure_sequence:
            raise KeyError('no figure sequence named %r' % (name,))
        data_seq = self.mpl_figure_sequence[name]
        if not data_seq:
            raise IndexError('figure sequence %r is empty' % (name,))

        video_rows, video_cols = data_seq[0].shape[:2]
        # cv2.resize takes the target size as (width, height).
        data_seq = [cv2.resize(f, (video_cols, video_rows)) for f in data_seq]
        data_seq = np.array(data_seq, dtype=np.uint8)

        self.video(data_seq, **kwargs)