".github/vscode:/vscode.git/clone" did not exist on "bc1534ff32c27c4c1a0f485b07031dde516489a4"
ddd.py 7.31 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import numpy as np

from models.losses import FocalLoss, L1Loss, BinRotLoss
from models.decode import ddd_decode
from models.utils import _sigmoid
from utils.debugger import Debugger
from utils.post_process import ddd_post_process
from utils.oracle_utils import gen_oracle_map
from .base_trainer import BaseTrainer


class DddLoss(torch.nn.Module):
    def __init__(self, opt):
        super(DddLoss, self).__init__()
        self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
        self.crit_reg = L1Loss()
        self.crit_rot = BinRotLoss()
        self.opt = opt

    def forward(self, outputs, batch):
        opt = self.opt

        hm_loss, dep_loss, rot_loss, dim_loss = 0, 0, 0, 0
        wh_loss, off_loss = 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            output['hm'] = _sigmoid(output['hm'])
            output['dep'] = 1. / (output['dep'].sigmoid() + 1e-6) - 1.

            if opt.eval_oracle_dep:
                output['dep'] = torch.from_numpy(gen_oracle_map(
                    batch['dep'].detach().cpu().numpy(),
                    batch['ind'].detach().cpu().numpy(),
                    opt.output_w, opt.output_h)).to(opt.device)

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.dep_weight > 0:
                dep_loss += self.crit_reg(output['dep'], batch['reg_mask'],
                                          batch['ind'], batch['dep']) / opt.num_stacks
            if opt.dim_weight > 0:
                dim_loss += self.crit_reg(output['dim'], batch['reg_mask'],
                                          batch['ind'], batch['dim']) / opt.num_stacks
            if opt.rot_weight > 0:
                rot_loss += self.crit_rot(output['rot'], batch['rot_mask'],
                                          batch['ind'], batch['rotbin'],
                                          batch['rotres']) / opt.num_stacks
            if opt.reg_bbox and opt.wh_weight > 0:
                wh_loss += self.crit_reg(output['wh'], batch['rot_mask'],
                                         batch['ind'], batch['wh']) / opt.num_stacks
            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['rot_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks
        loss = opt.hm_weight * hm_loss + opt.dep_weight * dep_loss + \
            opt.dim_weight * dim_loss + opt.rot_weight * rot_loss + \
            opt.wh_weight * wh_loss + opt.off_weight * off_loss

        loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'dep_loss': dep_loss,
                      'dim_loss': dim_loss, 'rot_loss': rot_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss}
        return loss, loss_stats


class DddTrainer(BaseTrainer):
    def __init__(self, opt, model, optimizer=None):
        super(DddTrainer, self).__init__(opt, model, optimizer=optimizer)

    def _get_losses(self, opt):
        loss_states = ['loss', 'hm_loss', 'dep_loss', 'dim_loss', 'rot_loss',
                       'wh_loss', 'off_loss']
        loss = DddLoss(opt)
        return loss_states, loss

    def debug(self, batch, output, iter_id):
        opt = self.opt
        wh = output['wh'] if opt.reg_bbox else None
        reg = output['reg'] if opt.reg_offset else None
        dets = ddd_decode(output['hm'], output['rot'], output['dep'],
                          output['dim'], wh=wh, reg=reg, K=opt.K)

        # x, y, score, r1-r8, depth, dim1-dim3, cls
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        calib = batch['meta']['calib'].detach().numpy()
        # x, y, score, rot, depth, dim1, dim2, dim3
        # if opt.dataset == 'gta':
        #   dets[:, 12:15] /= 3
        dets_pred = ddd_post_process(
            dets.copy(), batch['meta']['c'].detach().numpy(),
            batch['meta']['s'].detach().numpy(), calib, opt)
        dets_gt = ddd_post_process(
            batch['meta']['gt_det'].detach().numpy().copy(),
            batch['meta']['c'].detach().numpy(),
            batch['meta']['s'].detach().numpy(), calib, opt)
        # for i in range(input.size(0)):
        for i in range(1):
            debugger = Debugger(dataset=opt.dataset, ipynb=(opt.debug == 3),
                                theme=opt.debugger_theme)
            img = batch['input'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = ((img * self.opt.std + self.opt.mean) * 255.).astype(np.uint8)
            pred = debugger.gen_colormap(
                output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, 'hm_pred')
            debugger.add_blend_img(img, gt, 'hm_gt')
            # decode
            debugger.add_ct_detection(
                img, dets[i], show_box=opt.reg_bbox, center_thresh=opt.center_thresh,
                img_id='det_pred')
            debugger.add_ct_detection(
                img, batch['meta']['gt_det'][i].cpu().numpy().copy(),
                show_box=opt.reg_bbox, img_id='det_gt')
            debugger.add_3d_detection(
                batch['meta']['image_path'][i], dets_pred[i], calib[i],
                center_thresh=opt.center_thresh, img_id='add_pred')
            debugger.add_3d_detection(
                batch['meta']['image_path'][i], dets_gt[i], calib[i],
                center_thresh=opt.center_thresh, img_id='add_gt')
            # debugger.add_bird_view(
            #   dets_pred[i], center_thresh=opt.center_thresh, img_id='bird_pred')
            # debugger.add_bird_view(dets_gt[i], img_id='bird_gt')
            debugger.add_bird_views(
                dets_pred[i], dets_gt[i],
                center_thresh=opt.center_thresh, img_id='bird_pred_gt')

            # debugger.add_blend_img(img, pred, 'out', white=True)
            debugger.compose_vis_add(
                batch['meta']['image_path'][i], dets_pred[i], calib[i],
                opt.center_thresh, pred, 'bird_pred_gt', img_id='out')
            # debugger.add_img(img, img_id='out')
            if opt.debug == 4:
                debugger.save_all_imgs(
                    opt.debug_dir, prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def save_result(self, output, batch, results):
        opt = self.opt
        wh = output['wh'] if opt.reg_bbox else None
        reg = output['reg'] if opt.reg_offset else None
        dets = ddd_decode(output['hm'], output['rot'], output['dep'],
                          output['dim'], wh=wh, reg=reg, K=opt.K)

        # x, y, score, r1-r8, depth, dim1-dim3, cls
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        calib = batch['meta']['calib'].detach().numpy()
        # x, y, score, rot, depth, dim1, dim2, dim3
        dets_pred = ddd_post_process(
            dets.copy(), batch['meta']['c'].detach().numpy(),
            batch['meta']['s'].detach().numpy(), calib, opt)
        img_id = batch['meta']['img_id'].detach().numpy()[0]
        results[img_id] = dets_pred[0]
        for j in range(1, opt.num_classes + 1):
            keep_inds = (results[img_id][j][:, -1] > opt.center_thresh)
            results[img_id][j] = results[img_id][j][keep_inds]