ddd.py 4.26 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import time
import torch

from external.nms import soft_nms
from models.decode import ddd_decode
from models.utils import flip_tensor
from utils.image import get_affine_transform
from utils.post_process import ddd_post_process
from utils.debugger import Debugger
from utils.ddd_utils import compute_box_3d, project_to_image, alpha2rot_y
from utils.ddd_utils import draw_box_3d, unproject_2d_to_3d
from progress.bar import Bar
from .base_detector import BaseDetector

import numpy as np


class DddDetector(BaseDetector):
    def __init__(self, opt):
        super(DddDetector, self).__init__(opt)
        self.calib = np.array([[707.0493, 0, 604.0814, 45.75831],
                               [0, 707.0493, 180.5066, -0.3454157],
                               [0, 0, 1., 0.004981016]], dtype=np.float32)

    def pre_process(self, image, scale, calib=None):
        height, width = image.shape[0:2]

        inp_height, inp_width = self.opt.input_h, self.opt.input_w
        c = np.array([width / 2, height / 2], dtype=np.float32)
        if self.opt.keep_res:
            s = np.array([inp_width, inp_height], dtype=np.int32)
        else:
            s = np.array([width, height], dtype=np.int32)

        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        resized_image = image  # cv2.resize(image, (width, height))
        inp_image = cv2.warpAffine(
            resized_image, trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = (inp_image.astype(np.float32) / 255.)
        inp_image = (inp_image - self.mean) / self.std
        images = inp_image.transpose(2, 0, 1)[np.newaxis, ...]
        calib = np.array(calib, dtype=np.float32) if calib is not None \
            else self.calib
        images = torch.from_numpy(images)
        meta = {'c': c, 's': s,
                'out_height': inp_height // self.opt.down_ratio,
                'out_width': inp_width // self.opt.down_ratio,
                'calib': calib}
        return images, meta

    def process(self, images, return_time=False):
        with torch.no_grad():
            torch.cuda.synchronize()
            output = self.model(images)[-1]
            output['hm'] = output['hm'].sigmoid_()
            output['dep'] = 1. / (output['dep'].sigmoid() + 1e-6) - 1.
            wh = output['wh'] if self.opt.reg_bbox else None
            reg = output['reg'] if self.opt.reg_offset else None
            torch.cuda.synchronize()
            forward_time = time.time()

            dets = ddd_decode(output['hm'], output['rot'], output['dep'],
                              output['dim'], wh=wh, reg=reg, K=self.opt.K)
        if return_time:
            return output, dets, forward_time
        else:
            return output, dets

    def post_process(self, dets, meta, scale=1):
        dets = dets.detach().cpu().numpy()
        detections = ddd_post_process(
            dets.copy(), [meta['c']], [meta['s']], [meta['calib']], self.opt)
        self.this_calib = meta['calib']
        return detections[0]

    def merge_outputs(self, detections):
        results = detections[0]
        for j in range(1, self.num_classes + 1):
            if len(results[j] > 0):
                keep_inds = (results[j][:, -1] > self.opt.peak_thresh)
                results[j] = results[j][keep_inds]
        return results

    def debug(self, debugger, images, dets, output, scale=1):
        dets = dets.detach().cpu().numpy()
        img = images[0].detach().cpu().numpy().transpose(1, 2, 0)
        img = ((img * self.std + self.mean) * 255).astype(np.uint8)
        pred = debugger.gen_colormap(output['hm'][0].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, 'pred_hm')
        debugger.add_ct_detection(
            img, dets[0], show_box=self.opt.reg_bbox,
            center_thresh=self.opt.vis_thresh, img_id='det_pred')

    def show_results(self, debugger, image, results):
        debugger.add_3d_detection(
            image, results, self.this_calib,
            center_thresh=self.opt.vis_thresh, img_id='add_pred')
        debugger.add_bird_view(
            results, center_thresh=self.opt.vis_thresh, img_id='bird_pred')
        debugger.show_all_imgs(pause=self.pause)