# model_utils.py (last committed by jshilong)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from os.path import dirname, exists, join

import numpy as np
import torch
from mmengine import InstanceData

from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes


def _setup_seed(seed):
    """Seed every RNG a test may draw from and force deterministic cuDNN.

    Args:
        seed (int): Value handed to Python's ``random``, NumPy's global RNG
            and torch (CPU generator and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch_rng in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch_rng(seed)
    # Deterministic cuDNN kernels so repeated runs give identical results.
    torch.backends.cudnn.deterministic = True


def _get_config_directory():
    """Find the predefined detector config directory.

    Returns:
        str: Path of the ``configs`` directory at the repository root.

    Raises:
        FileNotFoundError: If the ``configs`` directory does not exist.
            (Subclass of ``Exception``, so older ``except Exception``
            handlers still catch it.)
    """
    try:
        # Assume we are running in the source mmdetection3d repo: the repo
        # root is the grandparent of the directory containing this file.
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmdet3d
        repo_dpath = dirname(dirname(mmdet3d.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise FileNotFoundError(f'Cannot find config path: {config_dpath}')
    return config_dpath


def _get_config_module(fname):
    """Parse a config file located under the repository ``configs`` dir.

    Args:
        fname (str): Config file path relative to the ``configs`` directory.

    Returns:
        mmcv.Config: The parsed configuration object.
    """
    from mmcv import Config
    return Config.fromfile(join(_get_config_directory(), fname))


def _get_model_cfg(fname):
    """Return a private deep copy of a config's ``model`` section.

    Because the result is a deep copy, a test may modify its parameters
    freely without affecting other tests.

    Args:
        fname (str): Config file path relative to the ``configs`` directory.

    Returns:
        The deep-copied ``model`` entry of the loaded config.
    """
    return copy.deepcopy(_get_config_module(fname).model)


def _get_detector_cfg(fname):
    """Return a deep copy of a config's ``model`` section for a detector.

    The model's ``train_cfg`` and ``test_cfg`` are replaced by fresh
    ``mmcv.Config`` objects built from independent deep copies, so a test
    can modify any of them without affecting other tests.

    Args:
        fname (str): Config file path relative to the ``configs`` directory.

    Returns:
        The deep-copied ``model`` entry with re-wrapped train/test configs.
    """
    import mmcv
    config = _get_config_module(fname)
    detector = copy.deepcopy(config.model)
    wrapped = {
        key: mmcv.Config(copy.deepcopy(getattr(config.model, key)))
        for key in ('train_cfg', 'test_cfg')
    }
    detector.update(train_cfg=wrapped['train_cfg'])
    detector.update(test_cfg=wrapped['test_cfg'])
    return detector


def _create_detector_inputs(seed=0,
                            with_points=True,
                            with_img=False,
                            num_gt_instance=20,
                            points_feat_dim=4,
                            gt_bboxes_dim=7,
                            num_classes=3):
    """Build a seeded, random single-sample input for detector unit tests.

    Args:
        seed (int): Seed passed to ``_setup_seed`` so the output is
            reproducible. Defaults to 0.
        with_points (bool): Whether to add a ``points`` tensor of shape
            ``(3, points_feat_dim)`` to the inputs. Defaults to True.
        with_img (bool): Whether to add an ``img`` tensor of shape
            ``(3, 10, 10)`` to the inputs. Defaults to False.
        num_gt_instance (int): Number of ground-truth boxes and labels.
            Defaults to 20.
        points_feat_dim (int): Feature dimension of each point. Defaults
            to 4.
        gt_bboxes_dim (int): Box dimension, also passed as ``box_dim`` to
            ``LiDARInstance3DBoxes``. Defaults to 7.
        num_classes (int): Labels are drawn with ``torch.randint`` from
            ``[0, num_classes)``. Defaults to 3.

    Returns:
        dict: ``inputs`` (dict of the requested tensors) and
        ``data_sample`` (a ``Det3DDataSample`` carrying ``gt_instances_3d``
        and an empty ``seg_data`` dict).
    """
    _setup_seed(seed)
    # The order of the random draws below is part of the contract: tests
    # rely on the same seed producing the same tensors.
    inputs_dict = dict()
    if with_points:
        points = torch.rand([3, points_feat_dim])
        inputs_dict['points'] = points
    if with_img:
        img = torch.rand(3, 10, 10)
        inputs_dict['img'] = img

    gt_instance_3d = InstanceData()
    gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
        torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
    gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
    data_sample = Det3DDataSample(
        metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
    data_sample.gt_instances_3d = gt_instance_3d
    data_sample.seg_data = dict()
    return dict(inputs=inputs_dict, data_sample=data_sample)