visualization.py

import os
from typing import List, Tuple

import numpy as np
import torch
from torch import Tensor
from torchvision.utils import make_grid


@torch.no_grad()
def make_disparity_image(disparity: Tensor) -> Tensor:
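    """Min-max normalize a disparity map to [0, 1] and return it as a detached CPU tensor."""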
    # normalize image to [0, 1]
    disparity = disparity.detach().cpu()
    disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
    return disparity


@torch.no_grad()
def make_disparity_image_pairs(disparity: Tensor, image: Tensor) -> Tuple[Tensor, Tensor]:
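    """Return the normalized disparity together with its reference image, rescaled from [-1, 1] to [0, 1]."""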
    disparity = make_disparity_image(disparity)
    # image is in [-1, 1], bring it to [0, 1]
    image = image.detach().cpu()
    image = image * 0.5 + 0.5
    return disparity, image


@torch.no_grad()
def make_disparity_sequence(disparities: List[Tensor]) -> Tensor:
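    """Normalize every disparity map to [0, 1] and stack the list of batches into a (sequence, batch, ...) tensor."""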
    # convert each disparity to [0, 1]
    for idx, disparity_batch in enumerate(disparities):
        disparities[idx] = torch.stack(list(map(make_disparity_image, disparity_batch)))
    # make the list into a batch
    disparity_sequences = torch.stack(disparities)
    return disparity_sequences


@torch.no_grad()
def make_pair_grid(*inputs, orientation="horizontal") -> Tensor:
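    """Interleave the input batches sample by sample and arrange them into a single image grid via ``make_grid``."""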
    # make a grid of images with the outputs and references side by side
    if orientation == "horizontal":
        # interleave the outputs and references
        canvas = torch.zeros_like(inputs[0])
        canvas = torch.cat([canvas] * len(inputs), dim=0)
        size = len(inputs)
        for idx, inp in enumerate(inputs):
            canvas[idx::size, ...] = inp
        grid = make_grid(canvas, nrow=len(inputs), padding=16, normalize=True, scale_each=True)
    elif orientation == "vertical":
        # interleave the outputs and references
        canvas = torch.cat(inputs, dim=0)
        size = len(inputs)
        for idx, inp in enumerate(inputs):
            canvas[idx::size, ...] = inp
        grid = make_grid(canvas, nrow=len(inputs[0]), padding=16, normalize=True, scale_each=True)
    else:
        raise ValueError("Unknown orientation: {}".format(orientation))
    return grid


@torch.no_grad()
def make_training_sample_grid(
    left_images: Tensor,
    right_images: Tensor,
    disparities: Tensor,
    masks: Tensor,
    predictions: List[Tensor],
) -> np.ndarray:
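    """Build a uint8 HWC grid showing, for each sample, the left image, right image,
    valid mask, ground-truth disparity and the final prediction."""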
    # detach images and renormalize to [0, 1]
    images_left = left_images.detach().cpu() * 0.5 + 0.5
    images_right = right_images.detach().cpu() * 0.5 + 0.5
    # detach the disparities and predictions
    disparities = disparities.detach().cpu()
    predictions = predictions[-1].detach().cpu()
    # keep only the first channel of pixels, and repeat it 3 times
    disparities = disparities[:, :1, ...].repeat(1, 3, 1, 1)
    predictions = predictions[:, :1, ...].repeat(1, 3, 1, 1)
    # unsqueeze and repeat the masks
    masks = masks.detach().cpu().unsqueeze(1).repeat(1, 3, 1, 1)
    # make a grid that will self normalize across the batch
    pred_grid = make_pair_grid(images_left, images_right, masks, disparities, predictions, orientation="horizontal")
    pred_grid = pred_grid.permute(1, 2, 0).numpy()
    pred_grid = (pred_grid * 255).astype(np.uint8)
    return pred_grid


@torch.no_grad()
def make_disparity_sequence_grid(predictions: List[Tensor], disparities: Tensor) -> np.ndarray:
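    """Build a uint8 HWC grid in which each row shows one sample's sequence of intermediate
    disparity predictions followed by the ground-truth disparity."""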
    # the ground-truth disparity is appended as the right-most entry of each sequence
    seq_len = len(predictions) + 1
    predictions = list(map(lambda x: x[:, :1, :, :].detach().cpu(), predictions + [disparities]))
    sequence = make_disparity_sequence(predictions)
    # swap axes so that each batch sample's sequence of predictions is laid out contiguously
    sequence = torch.swapaxes(sequence, 0, 1).contiguous().reshape(-1, 1, disparities.shape[-2], disparities.shape[-1])
    sequence = make_grid(sequence, nrow=seq_len, padding=16, normalize=True, scale_each=True)
    sequence = sequence.permute(1, 2, 0).numpy()
    sequence = (sequence * 255).astype(np.uint8)
    return sequence


@torch.no_grad()
def make_prediction_image_side_to_side(
    predictions: Tensor, disparities: Tensor, valid_mask: Tensor, save_path: str, prefix: str
) -> None:
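    """Normalize and mask the predictions and ground-truth disparities, then save one
    matplotlib figure per sample, side by side, as ``{prefix}_{idx}.png`` in ``save_path``."""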
    import matplotlib.pyplot as plt

    # normalize the predictions and disparities in [0, 1]
    predictions = (predictions - predictions.min()) / (predictions.max() - predictions.min())
    disparities = (disparities - disparities.min()) / (disparities.max() - disparities.min())
    predictions = predictions * valid_mask
    disparities = disparities * valid_mask

    predictions = predictions.detach().cpu()
    disparities = disparities.detach().cpu()

    for idx, (pred, gt) in enumerate(zip(predictions, disparities)):
        pred = pred.permute(1, 2, 0).numpy()
        gt = gt.permute(1, 2, 0).numpy()
        # plot pred and gt side by side
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(pred)
        ax[0].set_title("Prediction")
        ax[1].imshow(gt)
        ax[1].set_title("Ground Truth")
        save_name = os.path.join(save_path, "{}_{}.png".format(prefix, idx))
        plt.savefig(save_name)
        plt.close()
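

# Minimal usage sketch (an assumption, not part of the original module): the tensor
# shapes below are inferred from the functions above -- images in [-1, 1] with shape
# (B, 3, H, W), disparities with shape (B, 1, H, W), a validity mask with shape
# (B, H, W), and a list of intermediate disparity predictions from a refinement model.
if __name__ == "__main__":
    batch, height, width = 2, 64, 128
    left = torch.rand(batch, 3, height, width) * 2 - 1
    right = torch.rand(batch, 3, height, width) * 2 - 1
    gt_disparity = torch.rand(batch, 1, height, width) * 32
    valid_mask = (torch.rand(batch, height, width) > 0.2).float()
    # e.g. the per-iteration outputs of an iterative refinement model
    predictions = [torch.rand(batch, 1, height, width) * 32 for _ in range(3)]

    sample_grid = make_training_sample_grid(left, right, gt_disparity, valid_mask, predictions)
    sequence_grid = make_disparity_sequence_grid(predictions, gt_disparity)
    print(sample_grid.shape, sequence_grid.shape)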