exdet.py 5.62 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths

import os
import cv2
import time
import torch

from external.nms import soft_nms
from models.decode import exct_decode, agnex_ct_decode
from models.utils import flip_tensor
from utils.image import get_affine_transform, transform_preds
from utils.post_process import ctdet_post_process
from utils.debugger import Debugger
from progress.bar import Bar
from .base_detector import BaseDetector

import numpy as np


class ExdetDetector(BaseDetector):
    """Detector driven by five heatmaps (``hm_t/l/b/r/c``) and a pluggable decoder.

    The heatmap suffixes presumably stand for top / left / bottom / right /
    centre -- confirm against the model definition.  Every decoded detection
    row is used here with columns ``0:4`` as box coordinates, column ``4`` as
    the score and the last column as the class index.
    """

    def __init__(self, opt):
        """Initialise the base detector and select the decode function.

        Args:
            opt: option namespace; ``opt.agnostic_ex`` selects
                ``agnex_ct_decode`` over ``exct_decode``.
        """
        super(ExdetDetector, self).__init__(opt)
        self.decode = agnex_ct_decode if opt.agnostic_ex else exct_decode

    def process(self, images, return_time=False):
        """Run the network on ``images`` and decode the heatmaps into detections.

        Args:
            images: input batch handed directly to ``self.model``.
            return_time: when true, additionally return the ``time.time()``
                stamp taken right after the forward pass, before decoding.

        Returns:
            ``(output, dets)`` or ``(output, dets, forward_time)``.  ``output``
            is the last element of the model's result; its ``hm_*`` entries
            have been replaced in place by their sigmoid.
        """
        # torch.cuda.synchronize() raises on CPU-only installs, so only
        # synchronise when CUDA actually exists.
        use_cuda = torch.cuda.is_available()
        with torch.no_grad():
            if use_cuda:
                torch.cuda.synchronize()
            output = self.model(images)[-1]
            # In-place sigmoid, applied in the order t, l, b, r, c.
            heats = [output['hm_{}'.format(part)].sigmoid_()
                     for part in ('t', 'l', 'b', 'r', 'c')]
            if use_cuda:
                torch.cuda.synchronize()
            forward_time = time.time()

            # The regression maps are optional extra positional arguments.
            regs = []
            if self.opt.reg_offset:
                regs = [output['reg_{}'.format(part)]
                        for part in ('t', 'l', 'b', 'r')]
            decode_args = heats + regs
            dets = self.decode(*decode_args,
                               K=self.opt.K,
                               scores_thresh=self.opt.scores_thresh,
                               center_thresh=self.opt.center_thresh,
                               aggr_weight=self.opt.aggr_weight)
        if return_time:
            return output, dets, forward_time
        return output, dets

    def debug(self, debugger, images, dets, output, scale=1):
        """Push heatmap overlays and raw decoded boxes to ``debugger``.

        Only the first image of the batch is visualised.

        Args:
            debugger: object providing ``gen_colormap``, ``add_blend_img``,
                ``add_img`` and ``add_coco_bbox``.
            images: network input batch; axes 2 and 3 are read as height and
                width.
            dets: decoded detections tensor (not modified; a copy is used).
            output: model output dict holding the ``hm_*`` heatmaps.
            scale: test scale, used only to label the debug images.
        """
        detection = dets.detach().cpu().numpy().copy()
        # Bring box coordinates from heatmap resolution to input resolution.
        detection[:, :, :4] *= self.opt.down_ratio

        inp_height, inp_width = images.shape[2], images.shape[3]
        out_id = 'out_{:.1f}'.format(scale)
        min_score = 0.01  # boxes at or below this score are not drawn

        for i in range(1):
            img = images[i].detach().cpu().numpy().transpose(1, 2, 0)
            # Undo the input normalisation to get a displayable uint8 image.
            img = ((img * self.std + self.mean) * 255).astype(np.uint8)

            # The four non-centre heatmaps are merged into one overlay by
            # pixel-wise maximum; the centre heatmap gets its own overlay.
            pred_hm = np.zeros((inp_height, inp_width, 3), dtype=np.uint8)
            for p in ['t', 'l', 'b', 'r', 'c']:
                tag = 'hm_{}'.format(p)
                pred = debugger.gen_colormap(
                    output[tag][i].detach().cpu().numpy(),
                    (inp_height, inp_width))
                if p == 'c':
                    debugger.add_blend_img(
                        img, pred, 'pred_{}_{:.1f}'.format(p, scale))
                else:
                    pred_hm = np.maximum(pred_hm, pred)
            debugger.add_blend_img(img, pred_hm, 'pred_{:.1f}'.format(scale))
            debugger.add_img(img, img_id=out_id)

            for det in detection[i]:
                if det[4] > min_score:
                    debugger.add_coco_bbox(det[:4], det[-1], det[4],
                                           img_id=out_id)

    def post_process(self, dets, meta, scale=1):
        """Map decoded detections back to original-image coordinates.

        Args:
            dets: decoded detections tensor; it is reshaped to
                ``(2, -1, 14)``, i.e. two batch entries with 14 values per
                detection.  The second entry is mirrored horizontally, so it
                presumably comes from a flipped copy of the image (flip
                testing) -- confirm against the caller.
            meta: dict providing ``out_width``, ``out_height``, ``c`` and
                ``s`` for ``transform_preds``.
            scale: test scale; box coordinates are divided by it.

        Returns:
            A NumPy array of shape ``(N, 14)`` with columns ``0:4`` transformed.
        """
        out_width, out_height = meta['out_width'], meta['out_height']
        # Copy first: for a CPU tensor ``.numpy()`` shares memory with it,
        # and everything below edits the array in place.
        dets = dets.detach().cpu().numpy().copy().reshape(2, -1, 14)

        # Mirror the x coordinates of the second entry about the output
        # width; columns 0 and 2 swap so that x1 <= x2 still holds.
        mirrored = dets[1]
        mirrored[:, [0, 2]] = out_width - mirrored[:, [2, 0]]

        dets = dets.reshape(1, -1, 14)
        out_size = (out_width, out_height)
        dets[0, :, 0:2] = transform_preds(
            dets[0, :, 0:2], meta['c'], meta['s'], out_size)
        dets[0, :, 2:4] = transform_preds(
            dets[0, :, 2:4], meta['c'], meta['s'], out_size)
        dets[:, :, 0:4] /= scale
        return dets[0]

    def merge_outputs(self, detections):
        """Merge per-scale detections into per-class results.

        Args:
            detections: sequence of 2-D arrays (one per test scale) whose
                column 4 is the score and whose last column is the class
                index.

        Returns:
            Dict mapping the 1-based class id to a float32 array holding the
            first five columns of that class's detections.  When more than
            ``self.max_per_image`` detections remain overall, only those
            scoring at least as high as the ``max_per_image``-th best are
            kept.
        """
        detections = np.concatenate(
            list(detections), axis=0).astype(np.float32)
        # Drop detections with a non-positive score.
        detections = detections[detections[:, 4] > 0]
        classes = detections[:, -1]

        results = {}
        for j in range(self.num_classes):
            cls_dets = detections[classes == j][:, 0:7].astype(np.float32)
            # The return value is ignored; soft_nms presumably rescales the
            # scores in place -- confirm against external.nms.
            soft_nms(cls_dets, Nt=0.5, method=2)
            results[j + 1] = cls_dets[:, 0:5]

        class_ids = range(1, self.num_classes + 1)
        scores = np.hstack([results[j][:, -1] for j in class_ids])
        if len(scores) > self.max_per_image:
            kth = len(scores) - self.max_per_image
            thresh = np.partition(scores, kth)[kth]
            for j in class_ids:
                results[j] = results[j][results[j][:, -1] >= thresh]
        return results

    def show_results(self, debugger, image, results):
        """Draw every result above ``opt.vis_thresh`` on ``image`` and show it.

        Args:
            debugger: object providing ``add_img``, ``add_coco_bbox`` and
                ``show_all_imgs``.
            image: image to draw on.
            results: dict as returned by ``merge_outputs`` (1-based class id
                mapped to rows whose columns ``0:4`` are the box and column
                ``4`` the score).
        """
        img_id = 'exdet'
        debugger.add_img(image, img_id=img_id)
        for j in range(1, self.num_classes + 1):
            for bbox in results[j]:
                if bbox[4] > self.opt.vis_thresh:
                    # The debugger expects a 0-based class index.
                    debugger.add_coco_bbox(
                        bbox[:4], j - 1, bbox[4], img_id=img_id)
        debugger.show_all_imgs(pause=self.pause)