__init__.py 1.52 KB
Newer Older
1
from collections import namedtuple
Shaoshuai Shi's avatar
Shaoshuai Shi committed
2
3
4
5

import numpy as np
import torch

6
7
from .detectors import build_detector

djiajunustc's avatar
djiajunustc committed
8
9
10
11
12
13
14
try:
    import kornia
except Exception:
    # Optional dependency: kornia is only required by CaDDN (it converts the
    # 'images' batch entry in load_data_to_gpu), so a missing or broken kornia
    # install must not prevent importing this package. ``except Exception``
    # instead of a bare ``except:`` so KeyboardInterrupt / SystemExit still
    # propagate.
    pass


15
16
17
18
19
20
21
22

def build_network(model_cfg, num_class, dataset):
    """Construct the detector described by ``model_cfg``.

    Thin wrapper around ``build_detector``: the three arguments are forwarded
    unchanged as keyword arguments.

    Args:
        model_cfg: model configuration passed through to ``build_detector``.
        num_class: number of classes passed through to ``build_detector``.
        dataset: dataset object passed through to ``build_detector``.

    Returns:
        Whatever ``build_detector`` returns.
    """
    return build_detector(model_cfg=model_cfg, num_class=num_class, dataset=dataset)


23
24
def load_data_to_gpu(batch_dict):
    """Move the array entries of a batch dict onto the CUDA device, in place.

    ``batch_dict`` is mutated in place and nothing is returned. Per key:

    * ``'camera_imgs'``: moved with ``.cuda()`` whatever its type
      (presumably already a torch tensor -- confirm against the collate fn).
    * any other value that is not an ``np.ndarray``: left untouched.
    * host-only keys (``frame_id``, ``metadata``, ``calib``, ``image_paths``,
      ``ori_shape``, ``img_process_infos``): left as numpy arrays.
    * ``'images'``: converted with ``kornia.image_to_tensor``, cast to
      float32, moved to the GPU and made contiguous.
    * ``'image_shape'``: converted to an int32 CUDA tensor.
    * everything else: converted to a float32 CUDA tensor.

    Raises:
        ImportError: if an ``'images'`` array is present but kornia is not
            installed (kornia is an optional dependency).
    """
    # ndarray entries that are intentionally not converted to tensors.
    host_only_keys = frozenset(
        ['frame_id', 'metadata', 'calib', 'image_paths', 'ori_shape', 'img_process_infos']
    )

    # Only existing keys are reassigned (the dict size never changes), so it
    # is safe to write into batch_dict while iterating over items().
    for key, val in batch_dict.items():
        if key == 'camera_imgs':
            batch_dict[key] = val.cuda()
        elif not isinstance(val, np.ndarray):
            continue
        elif key in host_only_keys:
            continue
        elif key == 'images':
            # Imported lazily so a missing optional dependency gives a clear
            # error here instead of a bare NameError on ``kornia``.
            try:
                import kornia
            except ImportError as err:
                raise ImportError(
                    "kornia is required to load 'images' onto the GPU (used by CaDDN)"
                ) from err
            batch_dict[key] = kornia.image_to_tensor(val).float().cuda().contiguous()
        elif key == 'image_shape':
            batch_dict[key] = torch.from_numpy(val).int().cuda()
        else:
            batch_dict[key] = torch.from_numpy(val).float().cuda()
37
38


39
40
41
42
def model_fn_decorator():
    """Build the per-iteration step function (presumably used by the training loop).

    Returns:
        A callable ``model_func(model, batch_dict)`` that:

        1. moves ``batch_dict`` to the GPU in place via ``load_data_to_gpu``;
        2. calls ``model(batch_dict)``, which must return
           ``(ret_dict, tb_dict, disp_dict)``;
        3. reduces ``ret_dict['loss']`` with ``.mean()``;
        4. advances the model's global step counter;
        5. returns a ``ModelReturn(loss, tb_dict, disp_dict)`` namedtuple.
    """
    ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])

    def model_func(model, batch_dict):
        load_data_to_gpu(batch_dict)
        ret_dict, tb_dict, disp_dict = model(batch_dict)
        loss = ret_dict['loss'].mean()

        # When the model itself lacks update_global_step, it is assumed to be
        # a wrapper exposing the real model as ``.module`` (e.g. a
        # DataParallel-style wrapper -- confirm), so step that one instead.
        step_owner = model if hasattr(model, 'update_global_step') else model.module
        step_owner.update_global_step()

        return ModelReturn(loss=loss, tb_dict=tb_dict, disp_dict=disp_dict)

    return model_func