visualizers.py 12.8 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import numpy as np
import numpy.lib.recfunctions as nlr
import cv2 as cv
import colorsys
from skimage.measure import block_reduce
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from ..representations.image import events_to_image, TimestampImage
from ..representations.voxel_grid import events_to_voxel
from ..util.event_util import clip_events_to_bounds
from .visualization_utils import *
from tqdm import tqdm

class Visualizer:
    """
    Common interface shared by the event visualizers in this module.

    The base class cannot be instantiated: both the constructor and
    `plot_events` raise NotImplementedError and must be overridden.
    """

    def __init__(self):
        raise NotImplementedError

    def plot_events(self, data, save_path, **kwargs):
        """
        Draw `data` and write the result to `save_path` (must be overridden).
        """
        raise NotImplementedError

    @staticmethod
    def unpackage_events(events):
        """
        Split a 2D event array into its first four columns.
        :param: events array indexed as events[:, k]
        :returns: tuple of (column 0 cast to int, column 1 cast to int,
            column 2 unchanged, column 3 unchanged)
        """
        int_cols = [events[:, k].astype(int) for k in (0, 1)]
        raw_cols = [events[:, k] for k in (2, 3)]
        return int_cols[0], int_cols[1], raw_cols[0], raw_cols[1]

class TimeStampImageVisualizer(Visualizer):
    """
    Renders events as a timestamp image and saves it to disk.
    """

    def __init__(self, sensor_size):
        """
        :param: sensor_size forwarded to TimestampImage and kept on the instance
        """
        self.ts_img = TimestampImage(sensor_size)
        self.sensor_size = sensor_size

    def plot_events(self, data, save_path, **kwargs):
        """
        Build a timestamp image from the events and save it with the
        'viridis' colormap.
        :param: data dict whose 'events' entry is an array with columns
            (x, y, t, p), as split by unpackage_events
        :param: save_path file the figure is written to
        :param: kwargs accepted for interface compatibility, unused
        """
        xs, ys, ts, ps = self.unpackage_events(data['events'])
        # The first timestamp is handed to the accumulator as its initial time
        self.ts_img.set_init(ts[0])
        self.ts_img.add_events(xs, ys, ts, ps)
        timestamp_image = self.ts_img.get_image()
        fig = plt.figure()
        try:
            plt.imshow(timestamp_image, cmap='viridis')
            ensure_dir(save_path)
            fig.savefig(save_path, transparent=True, dpi=600, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is explicitly closed;
            # without this, repeated calls accumulate open figures.
            plt.close(fig)

class EventImageVisualizer(Visualizer):
    """
    Renders events as a single accumulated grayscale event image.
    """

    def __init__(self, sensor_size):
        """
        :param: sensor_size image size handed to events_to_image
        """
        self.sensor_size = sensor_size

    def plot_events(self, data, save_path, **kwargs):
        """
        Accumulate the events into an image, rescale it to [0, 1] and save it
        with the 'gray' colormap.
        :param: data dict whose 'events' entry is an array with columns
            (x, y, t, p), as split by unpackage_events
        :param: save_path file the figure is written to
        :param: kwargs accepted for interface compatibility, unused
        """
        # xs, ys are already integer arrays after unpackage_events
        xs, ys, ts, ps = self.unpackage_events(data['events'])
        img = events_to_image(xs, ys, ps, self.sensor_size, interpolation=None, padding=False)
        mn, mx = np.min(img), np.max(img)
        span = mx - mn
        # A constant image (e.g. no events) has zero range; dividing would
        # fill the image with NaN, so show it as uniformly zero instead.
        img = (img - mn) / span if span > 0 else np.zeros_like(img, dtype=float)

        fig = plt.figure()
        try:
            plt.imshow(img, cmap='gray')
            ensure_dir(save_path)
            fig.savefig(save_path, transparent=True, dpi=600, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is explicitly closed;
            # without this, repeated calls accumulate open figures.
            plt.close(fig)


class EventsVisualizer(Visualizer):
    """
    Draws events as a 3D scatter in a spatiotemporal (x, t, y) volume,
    optionally with image frames placed at their timestamps.
    """

    def __init__(self, sensor_size):
        """
        :param: sensor_size (height, width) of the sensor; may be None, in
            which case the size is inferred from the first frame or the events
        """
        self.sensor_size = sensor_size

    @staticmethod
    def _style_axes(ax, show_axes):
        """
        Remove grid, pane fills and ticks; hide the axis lines unless show_axes.
        """
        ax.grid(False)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            # Hide panes
            axis.pane.fill = False
            if not show_axes:
                # Hide spines (the w_xaxis/w_yaxis/w_zaxis aliases were
                # removed from recent Matplotlib; xaxis/yaxis/zaxis are the
                # same objects)
                axis.line.set_color((1.0, 1.0, 1.0, 0.0))
        if not show_axes:
            ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

    def plot_events(self, data, save_path,
            num_compress='auto', num_show=1000,
            event_size=2, elev=0, azim=45, show_events=True,
            show_frames=True, show_plot=False, crop=None, compress_front=False,
            marker='.', stride = 1, invert=False, show_axes=False, flip_x=False):
        """
        Given events, plot these in a spatiotemporal volume.
        :param: data dict with entries 'events' (array with columns
            x, y, t, p), 'frame' (an image, a list/tuple of images, or None)
            and 'frame_ts' (the position on the temporal axis of each frame)
        :param: save_path if not None, will save plot to here
        :param: num_compress number of events that are compressed into an
            event image (overlaid on the frames, or drawn at one end of the
            volume when no frames are drawn). 'auto' uses half the number of
            pixels (at most the number of events), 'all' uses every event,
            0 disables it
        :param: num_show sets the number of events to plot. If set to -1
            will plot all of the events (can be potentially expensive)
        :param: event_size sets the size of the plotted events
        :param: elev sets the elevation of the plot
        :param: azim sets the azimuth of the plot
        :param: show_events if False, events are drawn fully transparent
        :param: show_frames if False, the frames are not drawn
        :param: show_plot if True, show the figure interactively
        :param: crop a list of length 4 in the format
            [y_min, y_max, x_min, x_max] (this is how the frames are sliced);
            defaults to the full image
        :param: compress_front if True, the compressed events are the last
            ones and are drawn at the last timestamp, otherwise the first
            ones are drawn at the first timestamp
        :param: marker matplotlib marker used for the events
        :param: stride row/column stride used when drawing the frames
        :param: invert if True, use light colors (for dark backgrounds)
        :param: show_axes if False, the axis lines and frame are hidden
        :param: flip_x if True, mirror events and frames along x
        """
        xs, ys, ts, ps = self.unpackage_events(data['events'])
        imgs, img_ts = data['frame'], data['frame_ts']
        if imgs is None:
            # No frame supplied: behave as if an empty frame list was given
            imgs, img_ts = [], []
        elif not isinstance(imgs, (list, tuple)):
            imgs, img_ts = [imgs], [img_ts]

        # Resolve the image size first, since flipping the events needs it
        img_size = self.sensor_size
        if img_size is None:
            if len(imgs) > 0:
                img_size = imgs[0].shape[0:2]
            elif len(xs) > 0:
                img_size = [int(np.max(ys))+1, int(np.max(xs))+1]
            else:
                raise ValueError("sensor_size is None and cannot be inferred without frames or events")

        ys = img_size[0]-ys
        xs = img_size[1]-xs if flip_x else xs
        #Crop events
        crop = [0, img_size[0], 0, img_size[1]] if crop is None else crop
        xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop, set_zero=False)
        xs, ys = xs-crop[2], ys-crop[0]

        if len(xs) < 2:
            # Too few events to span a time axis: use two placeholder events,
            # located at the first frame timestamp when there is one
            have_frame_ts = img_ts is not None and len(img_ts) > 0 and img_ts[0] is not None
            xs = np.array([0,0])
            ys = np.array([0,0])
            if have_frame_ts:
                ts = np.array([img_ts[0], img_ts[0]+0.000001])
            else:
                ts = np.array([0,0])
            ps = np.array([0.,0.])

        #Defaults and range checks
        num_show = len(xs) if num_show == -1 else num_show
        skip = max(len(xs)//num_show, 1)
        if num_compress == 'all':
            num_compress = len(xs)
        elif num_compress == 'auto':
            num_compress = min(int(img_size[0]*img_size[1]*0.5), len(xs))
        # Any other value is used as given (it used to be forced to 0)
        xs, ys, ts, ps = xs[::skip], ys[::skip], ts[::skip], ps[::skip]
        # Never compress more events than remain after subsampling
        num_compress = min(num_compress, len(xs))

        #Prepare the plot, set colors
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d', proj_type = 'ortho')
            colors = ['r' if p>0 else ('#00DAFF' if invert else 'b') for p in ps]
            event_alpha = 1.0 if show_events else 0
            compress_color = 'w' if invert else 'k'

            if len(imgs)>0 and show_frames:
                #Plot images
                for frame, frame_t in zip(imgs, img_ts):
                    frame = frame[crop[0]:crop[1], crop[2]:crop[3]].astype(float)
                    frame = np.flip(frame, axis=0)
                    frame = np.flip(frame, axis=1) if flip_x else frame
                    if len(frame.shape)==2:
                        frame = np.stack((frame, frame, frame), axis=2)
                    # Scale to [0, 1] BEFORE overlaying the events: the overlay
                    # clips to [0, 1], which would wipe out a 0..255 image
                    frame = frame/255.0
                    if num_compress > 0:
                        events_img = events_to_image(xs[0:num_compress], ys[0:num_compress],
                                np.ones(num_compress), sensor_size=frame.shape[0:2])
                        events_img[events_img>0] = 1
                        # Mark pixels that received events in the green channel
                        frame[:,:,1] += events_img[:,:]
                        frame = np.clip(frame, 0, 1)
                    x, y = np.ogrid[0:frame.shape[0], 0:frame.shape[1]]
                    # Index of the first event at or after the frame timestamp
                    event_idx = np.searchsorted(ts, frame_t)

                    # Events before the frame, the frame, then the remaining
                    # events (through the last one)
                    ax.scatter(xs[0:event_idx], ts[0:event_idx], ys[0:event_idx], zdir='z',
                            c=colors[0:event_idx], facecolors=colors[0:event_idx],
                            s=event_size, marker=marker, linewidths=0, alpha=event_alpha)

                    ax.plot_surface(y, frame_t, x, rstride=stride, cstride=stride, facecolors=frame, alpha=1)

                    ax.scatter(xs[event_idx:], ts[event_idx:], ys[event_idx:], zdir='z',
                            c=colors[event_idx:], facecolors=colors[event_idx:],
                            s=event_size, marker=marker, linewidths=0, alpha=event_alpha)

            elif num_compress > 0:
                # Plot events (already subsampled above, so no second stride)
                ax.scatter(xs, ts, ys, zdir='z', c=colors, facecolors=colors,
                        s=event_size, marker=marker, linewidths=0, alpha=event_alpha)
                if not compress_front:
                    ax.scatter(xs[0:num_compress], np.ones(num_compress)*ts[0], ys[0:num_compress],
                            marker=marker, zdir='z', c=compress_color, s=event_size)
                else:
                    ax.scatter(xs[-num_compress:], np.ones(num_compress)*ts[-1], ys[-num_compress:],
                            marker=marker, zdir='z', c=compress_color, s=event_size)
            else:
                # Plot events
                ax.scatter(xs, ts, ys, zdir='z', c=colors, facecolors=colors,
                        s=event_size, marker=marker, linewidths=0, alpha=event_alpha)

            ax.view_init(elev=elev, azim=azim)
            self._style_axes(ax, show_axes)
            # Flush axes
            ax.set_xlim3d(0, img_size[1])
            ax.set_ylim3d(ts[0], ts[-1])
            ax.set_zlim3d(0,img_size[0])

            # Save before showing: once an interactive window has been
            # closed the figure may no longer render, giving a blank file
            if save_path is not None:
                ensure_dir(save_path)
                print("Saving to {}".format(save_path))
                fig.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight')
            if show_plot:
                plt.show()
        finally:
            plt.close(fig)

class VoxelVisualizer(Visualizer):
    """
    Draws events as a coloured 3D voxel grid (positive voxels in red tones,
    negative voxels in blue tones).
    """

    def __init__(self, sensor_size):
        """
        :param: sensor_size (height, width) of the sensor; may be None, in
            which case the size is inferred from the first frame or the events
        """
        self.sensor_size = sensor_size

    @staticmethod
    def increase_brightness(rgb, increase=0.5):
        """
        Brighten a 4D colour array by raising its HSV value channel.
        :param: rgb 4D array with the colour components on the last axis;
            it is multiplied by 255 and cast to uint8, so values are
            presumably expected in [0, 1]
        :param: increase amount added to the value channel, as a fraction
            of 255 (the result is clipped to [0, 255])
        :returns: float array with the same axis order as the input,
            divided by 255
        """
        rgb = (rgb*255).astype('uint8')
        # OpenCV converts one 3D image at a time, so walk along axis 1
        num_slices = rgb.shape[1]
        hsv = (np.stack([cv.cvtColor(rgb[:,x,:,:], cv.COLOR_RGB2HSV) for x in range(num_slices)])).astype(float)
        hsv[:,:,:,2] = np.clip(hsv[:,:,:,2] + increase*255, 0, 255)
        hsv = hsv.astype('uint8')
        rgb_new = np.stack([cv.cvtColor(hsv[x,:,:,:], cv.COLOR_HSV2RGB) for x in range(num_slices)])
        # Stacking moved axis 1 to the front; swap it back
        rgb_new = (rgb_new.transpose(1,0,2,3)).astype(float)
        return rgb_new/255.0

    def plot_events(self, data, save_path, bins=5, crop=None, elev=0, azim=45, show_axes=False,
            show_plot=False, flip_x=False, size_reduction=10, max_events=10000):
        """
        Turn the events into a voxel grid and draw it as a 3D voxel plot.
        Returns without plotting when there are fewer than two events, or no
        events are left after cropping.
        :param: data dict with entry 'events' (array with columns x, y, t, p);
            'frame' is only read when the sensor size has to be inferred
        :param: save_path if not None, will save plot to here
        :param: bins number of bins passed to events_to_voxel
        :param: crop if set, events are clipped to these bounds and the
            voxel grid size is taken from crop_to_size(crop)
        :param: elev sets the elevation of the plot
        :param: azim sets the azimuth of the plot
        :param: show_axes if False, the axis lines and frame are hidden
        :param: show_plot if True, show the figure interactively
        :param: flip_x if True, mirror the events along x
        :param: size_reduction spatial block size over which the voxel grid
            is averaged before plotting
        :param: max_events only this many leading events are voxelized
        """
        xs, ys, ts, ps = self.unpackage_events(data['events'])
        if len(xs) < 2:
            return

        # Work on a local size so that cropping in one call does not change
        # the sensor size seen by later calls
        sensor_size = self.sensor_size
        if sensor_size is None:
            frame = data['frame']
            if isinstance(frame, (list, tuple)):
                frame = frame[0] if len(frame) > 0 else None
            if frame is not None:
                sensor_size = frame.shape[0:2]
            else:
                sensor_size = [np.max(ys)+1, np.max(xs)+1]

        ys = sensor_size[0]-ys
        xs = sensor_size[1]-xs if flip_x else xs

        if crop is not None:
            xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop)
            sensor_size = crop_to_size(crop)
            xs, ys = xs-crop[2], ys-crop[0]
        xs, ys, ts, ps = xs[0:max_events], ys[0:max_events], ts[0:max_events], ps[0:max_events]
        if len(xs) == 0:
            return
        voxels = events_to_voxel(xs, ys, ts, ps, bins, sensor_size=sensor_size)
        voxels = block_reduce(voxels, block_size=(1,size_reduction,size_reduction), func=np.mean, cval=0)
        # Pad the first axis with empty slices up to the size of the second
        # axis; nothing is added when it is already at least as large
        dimdiff = max(voxels.shape[1]-voxels.shape[0], 0)
        filler = np.zeros((dimdiff, *voxels.shape[1:]))
        voxels = np.concatenate((filler, voxels), axis=0)
        voxels = voxels.transpose(0,2,1)

        pltvoxels = voxels != 0
        pvp, nvp = voxels > 0, voxels < 0
        min_r, min_b, max_g = 80/255.0, 80/255.0, 0/255.0

        # Normalize by the largest magnitude; an all-zero grid stays zero
        # instead of turning into NaN
        peak = max(np.abs(np.max(voxels)), np.abs(np.min(voxels)))
        vox_cols = voxels/peak if peak > 0 else np.zeros_like(voxels, dtype=float)
        pvox, nvox = vox_cols*np.where(vox_cols > 0, 1, 0), np.abs(vox_cols)*np.where(vox_cols < 0, 1, 0)
        pvox, nvox = pvox*(1-min_r)+min_r, nvox*(1-min_b)+min_b

        colors = np.empty(voxels.shape, dtype=object)

        increase = 0.5
        redvals = np.stack((pvox, (1.0-pvox)*max_g, pvox-min_r), axis=3)
        redvals = self.increase_brightness(redvals, increase=increase)
        # One colour tuple per voxel, so it can be assigned through a mask
        redvals = nlr.unstructured_to_structured(redvals).astype('O')

        bluvals = np.stack((nvox-min_b, (1.0-nvox)*max_g, nvox), axis=3)
        bluvals = self.increase_brightness(bluvals, increase=increase)
        bluvals = nlr.unstructured_to_structured(bluvals).astype('O')

        colors[pvp] = redvals[pvp]
        colors[nvp] = bluvals[nvp]

        fig = plt.figure()
        try:
            # fig.gca(projection=...) is no longer accepted by recent Matplotlib
            ax = fig.add_subplot(111, projection='3d')
            ax.voxels(pltvoxels, facecolors=colors)
            ax.view_init(elev=elev, azim=azim)

            ax.grid(False)
            for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
                # Hide panes
                axis.pane.fill = False
                if not show_axes:
                    # Hide spines (the w_xaxis/w_yaxis/w_zaxis aliases were
                    # removed from recent Matplotlib)
                    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
            if not show_axes:
                ax.set_frame_on(False)
            # Hide xy axes
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_zticks([])

            ax.xaxis.set_visible(False)
            ax.axes.get_yaxis().set_visible(False)

            # Save before showing: once an interactive window has been
            # closed the figure may no longer render, giving a blank file
            if save_path is not None:
                ensure_dir(save_path)
                print("Saving to {}".format(save_path))
                fig.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight')
            if show_plot:
                plt.show()
        finally:
            plt.close(fig)