exdet.py 3.98 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import numpy as np
import cv2
import sys
import time
from utils.debugger import Debugger
from models.data_parallel import DataParallel
from models.losses import FocalLoss, RegL1Loss
from models.decode import agnex_ct_decode, exct_decode
from models.utils import _sigmoid
from .base_trainer import BaseTrainer


class ExdetLoss(torch.nn.Module):
    """Training loss for the exdet task.

    Sums a heatmap loss over the five heatmaps 'hm_t', 'hm_l', 'hm_b',
    'hm_r' and 'hm_c' with an optional L1 offset-regression loss for the
    four non-centre parts. Every per-stack term is divided by
    ``opt.num_stacks``, so both losses are averaged over the network stacks.
    """

    def __init__(self, opt):
        super(ExdetLoss, self).__init__()
        # Heatmap criterion: plain MSE when opt.mse_loss is set, focal loss
        # otherwise.
        if opt.mse_loss:
            self.crit = torch.nn.MSELoss()
        else:
            self.crit = FocalLoss()
        # Offset-regression criterion, only used when offsets are enabled.
        self.crit_reg = RegL1Loss()
        self.opt = opt
        # Heatmap key suffixes; presumably top / left / bottom / right
        # extreme points plus the centre ('c') -- confirm with the dataset.
        self.parts = ['t', 'l', 'b', 'r', 'c']

    def forward(self, outputs, batch):
        """Compute the weighted exdet loss.

        Args:
            outputs: indexable per-stack sequence of dicts holding the raw
                'hm_<part>' heatmaps and 'reg_<part>' offsets. NOTE: every
                'hm_<part>' entry is overwritten in place with its
                ``_sigmoid`` activation, so callers that reuse these dicts
                afterwards see the activated maps.
            batch: dict of targets keyed 'hm_<part>', 'reg_<part>',
                'ind_<part>' and 'reg_mask'.

        Returns:
            ``(loss, loss_stats)``, where ``loss`` is
            ``opt.hm_weight * hm_loss + opt.off_weight * off_loss`` and
            ``loss_stats`` maps 'loss', 'off_loss' and 'hm_loss' to those
            values.
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks
        hm_loss, reg_loss = 0, 0
        for stack_idx in range(n_stacks):
            stack_out = outputs[stack_idx]
            for part in self.parts:
                hm_key = 'hm_' + part
                # Write the activation back so later consumers reuse it.
                stack_out[hm_key] = _sigmoid(stack_out[hm_key])
                hm_loss += self.crit(stack_out[hm_key], batch[hm_key]) / n_stacks
                if part == 'c':
                    # The offset loss only covers the four non-centre parts.
                    continue
                if cfg.reg_offset and cfg.off_weight > 0:
                    reg_key = 'reg_' + part
                    reg_loss += self.crit_reg(stack_out[reg_key],
                                              batch['reg_mask'],
                                              batch['ind_' + part],
                                              batch[reg_key]) / n_stacks
        # NOTE(review): when no offset term is added, off_loss stays the
        # Python int 0 rather than a tensor -- confirm the trainer's stat
        # logging accepts that.
        total = cfg.hm_weight * hm_loss + cfg.off_weight * reg_loss
        stats = {'loss': total, 'off_loss': reg_loss, 'hm_loss': hm_loss}
        return total, stats


class ExdetTrainer(BaseTrainer):
    """Trainer for the exdet task.

    Plugs ``ExdetLoss`` into ``BaseTrainer``, chooses the detection decoder
    from ``opt.agnostic_ex`` and provides a visual ``debug`` hook.
    """

    # Heatmap key suffixes iterated by ``debug``. This class previously never
    # assigned ``self.parts`` (only ExdetLoss did), so ``debug`` raised
    # AttributeError unless a base class happened to provide it. A tuple is
    # used because the sequence is never mutated.
    parts = ('t', 'l', 'b', 'r', 'c')

    def __init__(self, opt, model, optimizer=None):
        super(ExdetTrainer, self).__init__(opt, model, optimizer=optimizer)
        # agnex_ct_decode when opt.agnostic_ex is set (presumably
        # class-agnostic decoding), exct_decode otherwise.
        self.decode = agnex_ct_decode if opt.agnostic_ex else exct_decode

    def _get_losses(self, opt):
        """Return the names of the logged loss statistics and the loss module."""
        loss_states = ['loss', 'hm_loss', 'off_loss']
        loss = ExdetLoss(opt)
        return loss_states, loss

    def debug(self, batch, output, iter_id):
        """Visualise predicted vs. ground-truth heatmaps and boxes.

        Only the first image of the batch is rendered.

        Args:
            batch: dict holding the normalised 'input' images (CHW) and the
                'hm_<part>' target heatmaps.
            output: dict holding the predicted 'hm_<part>' heatmaps. These are
                presumably already activated, because ExdetLoss.forward
                overwrites them in place with their sigmoid -- confirm the
                call order in the base trainer.
            iter_id: iteration index, used as the file-name prefix when
                ``opt.debug == 4`` saves the images instead of showing them.
        """
        opt = self.opt
        detections = self.decode(output['hm_t'], output['hm_l'],
                                 output['hm_b'], output['hm_r'],
                                 output['hm_c']).detach().cpu().numpy()
        # Rescale the first four columns (the box coordinates handed to
        # add_coco_bbox below) from output-map to input resolution.
        detections[:, :, :4] *= opt.input_res / opt.output_res
        for i in range(1):  # only the first image of the batch is shown
            debugger = Debugger(
                dataset=opt.dataset, ipynb=(opt.debug == 3), theme=opt.debugger_theme)
            pred_hm = np.zeros((opt.input_res, opt.input_res, 3), dtype=np.uint8)
            gt_hm = np.zeros((opt.input_res, opt.input_res, 3), dtype=np.uint8)
            # Undo input normalisation: CHW -> HWC, x * std + mean, then 0-255.
            img = batch['input'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = ((img * opt.std + opt.mean) * 255.).astype(np.uint8)
            for p in self.parts:
                tag = 'hm_{}'.format(p)
                pred = debugger.gen_colormap(output[tag][i].detach().cpu().numpy())
                gt = debugger.gen_colormap(batch[tag][i].detach().cpu().numpy())
                if p != 'c':
                    # Merge the four non-centre maps into one overlay each.
                    pred_hm = np.maximum(pred_hm, pred)
                    gt_hm = np.maximum(gt_hm, gt)
                if p == 'c' or opt.debug > 2:
                    # Per-part overlays: always for the centre, all parts
                    # only at verbose debug levels.
                    debugger.add_blend_img(img, pred, 'pred_{}'.format(p))
                    debugger.add_blend_img(img, gt, 'gt_{}'.format(p))
            debugger.add_blend_img(img, pred_hm, 'pred')
            debugger.add_blend_img(img, gt_hm, 'gt')
            debugger.add_img(img, img_id='out')
            for det in detections[i]:
                # Column 4 is thresholded as the confidence; the last column
                # is passed through to add_coco_bbox (presumably the class id).
                if det[4] > 0.1:
                    debugger.add_coco_bbox(det[:4], det[-1], det[4], img_id='out')
            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)