# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from os.path import dirname, exists, join

import numpy as np
import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
                                Det3DDataSample, LiDARInstance3DBoxes,
                                PointData)


def setup_seed(seed):
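    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""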
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def _get_config_directory():
    """Find the predefined detector config directory."""
    try:
        # Assume we are running in the source mmdetection3d repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmdet3d
        repo_dpath = dirname(dirname(mmdet3d.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath


def _get_config_module(fname):
    """Load a configuration as a python module."""
    from mmengine import Config
    config_dpath = _get_config_directory()
    config_fpath = join(config_dpath, fname)
    config_mod = Config.fromfile(config_fpath)
    return config_mod


def get_model_cfg(fname):
    """Grab configs necessary to create a model.

    These are deep copied to allow for safe modification of parameters without
    influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)

    return model


def get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    These are deep copied to allow for safe modification of parameters without
    influencing other tests.
    """
    import mmengine
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)
    train_cfg = mmengine.Config(copy.deepcopy(config.model.train_cfg))
    test_cfg = mmengine.Config(copy.deepcopy(config.model.test_cfg))

    model.update(train_cfg=train_cfg)
    model.update(test_cfg=test_cfg)
    return model


def create_detector_inputs(seed=0,
                           with_points=True,
                           with_img=False,
                           img_size=10,
                           num_gt_instance=20,
                           num_points=10,
                           points_feat_dim=4,
                           num_classes=3,
                           gt_bboxes_dim=7,
                           with_pts_semantic_mask=False,
                           with_pts_instance_mask=False,
                           bboxes_3d_type='lidar'):
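    """Create fake inputs and ground truths for testing 3D detectors.

    Returns a dict with ``inputs`` (per-sample points and/or images) and
    ``data_samples`` (a single-element list of ``Det3DDataSample`` carrying
    random 2D/3D instances and optional point-wise masks).
    """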
    setup_seed(seed)
    assert bboxes_3d_type in ('lidar', 'depth', 'cam')
    bbox_3d_class = {
        'lidar': LiDARInstance3DBoxes,
        'depth': DepthInstance3DBoxes,
        'cam': CameraInstance3DBoxes
    }
    meta_info = dict()
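    # Dummy depth-to-image / lidar-to-image projection matrices serving as
    # placeholder calibration metadata.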
    meta_info['depth2img'] = np.array(
        [[5.23289349e+02, 3.68831943e+02, 6.10469439e+01],
         [1.09560138e+02, 1.97404735e+02, -5.47377738e+02],
         [1.25930002e-02, 9.92229998e-01, -1.23769999e-01]])
    meta_info['lidar2img'] = np.array(
        [[5.23289349e+02, 3.68831943e+02, 6.10469439e+01],
         [1.09560138e+02, 1.97404735e+02, -5.47377738e+02],
         [1.25930002e-02, 9.92229998e-01, -1.23769999e-01]])

    inputs_dict = dict()

    if with_points:
        points = torch.rand([num_points, points_feat_dim])
        inputs_dict['points'] = [points]

    if with_img:
        if isinstance(img_size, tuple):
            img = torch.rand(3, img_size[0], img_size[1])
            meta_info['img_shape'] = img_size
            meta_info['ori_shape'] = img_size
        else:
            img = torch.rand(3, img_size, img_size)
            meta_info['img_shape'] = (img_size, img_size)
            meta_info['ori_shape'] = (img_size, img_size)
        meta_info['scale_factor'] = np.array([1., 1.])
        inputs_dict['img'] = [img]

    gt_instance_3d = InstanceData()

    gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
        torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
    gt_instance_3d.labels_3d = torch.randint(0, num_classes,
                                             [num_gt_instance])
    data_sample = Det3DDataSample(
        metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
    data_sample.set_metainfo(meta_info)
    data_sample.gt_instances_3d = gt_instance_3d

    gt_instance = InstanceData()
    gt_instance.labels = torch.randint(0, num_classes, [num_gt_instance])
    gt_instance.bboxes = torch.rand(num_gt_instance, 4)
    # Convert random (x1, y1, w, h) boxes to (x1, y1, x2, y2) so that
    # x2 >= x1 and y2 >= y1.
    gt_instance.bboxes[:, 2:] = (
        gt_instance.bboxes[:, :2] + gt_instance.bboxes[:, 2:])

    data_sample.gt_instances = gt_instance
    data_sample.gt_pts_seg = PointData()
    if with_pts_instance_mask:
        pts_instance_mask = torch.randint(0, num_gt_instance, [num_points])
        data_sample.gt_pts_seg['pts_instance_mask'] = pts_instance_mask
    if with_pts_semantic_mask:
        pts_semantic_mask = torch.randint(0, num_classes, [num_points])
        data_sample.gt_pts_seg['pts_semantic_mask'] = pts_semantic_mask

    return dict(inputs=inputs_dict, data_samples=[data_sample])
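

# Minimal usage sketch (not part of the original helpers): the config filename
# below is a hypothetical placeholder; substitute any file that actually
# exists under the repo's `configs/` directory. Turning the cfg into a model
# is typically done with ``MODELS.build`` from ``mmdet3d.registry``.
if __name__ == '__main__':
    detector_cfg = get_detector_cfg(
        'pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car.py')
    packed_inputs = create_detector_inputs(
        num_gt_instance=5, with_points=True, with_img=True, img_size=(32, 32))
    print('detector type:', detector_cfg['type'])
    print('input keys:', list(packed_inputs['inputs'].keys()))
    print('num data samples:', len(packed_inputs['data_samples']))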