test_loading.py 2.77 KB
Newer Older
jshilong's avatar
jshilong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
from mmengine.testing import assert_allclose
from utils import create_dummy_data_info

from mmdet3d.core import DepthPoints, LiDARPoints
from mmdet3d.datasets.pipelines.loading import (LoadAnnotations3D,
                                                LoadPointsFromFile)


class TestLoadPointsFromFile(unittest.TestCase):
    """Unit tests for the ``LoadPointsFromFile`` pipeline transform."""

    def test_load_points_from_file(self):
        """Check the point class chosen per ``coord_type``, the ``use_dim``
        channel selection, the optional height channel and the ``repr``."""
        dims_kept = 3
        disk_client = dict(backend='disk')

        def build(coord_type, **extra):
            # Every variant reads 4 channels per point and keeps the first
            # ``dims_kept``; only the coordinate type and extras differ.
            return LoadPointsFromFile(
                coord_type=coord_type,
                load_dim=4,
                use_dim=dims_kept,
                file_client_args=disk_client,
                **extra)

        lidar_loader = build('LIDAR')
        sample = create_dummy_data_info()
        results = lidar_loader(sample)
        self.assertIn('points', results)
        self.assertIsInstance(results['points'], LiDARPoints)

        # The same sample dict is reused; the points are loaded again.
        depth_loader = build('DEPTH')
        results = depth_loader(sample)
        self.assertIsInstance(results['points'], DepthPoints)
        self.assertEqual(results['points'].shape[-1], dims_kept)

        height_loader = build('DEPTH', shift_height=True)
        results = height_loader(sample)
        # shift_height contributes one extra (height) channel
        self.assertEqual(results['points'].shape[-1], dims_kept + 1)

        description = repr(height_loader)
        for expected in ('shift_height=True', 'use_color=False',
                         'load_dim=4'):
            self.assertIn(expected, description)


class TestLoadAnnotations3D(unittest.TestCase):
    """Unit tests for the ``LoadAnnotations3D`` pipeline transform."""

    # Renamed from ``test_load_points_from_file`` (a copy-paste of the
    # sibling class's method name); this test exercises LoadAnnotations3D.
    def test_load_annotations_3d(self):
        """Load 3D boxes and labels from the dummy sample and check the
        transform's flags, its outputs and its ``repr``."""
        file_client_args = dict(backend='disk')

        load_anns_transform = LoadAnnotations3D(
            with_bbox_3d=True,
            with_label_3d=True,
            file_client_args=file_client_args)
        # Segmentation was not requested, so it is expected to stay off.
        self.assertIs(load_anns_transform.with_seg, False)
        self.assertIs(load_anns_transform.with_bbox_3d, True)
        self.assertIs(load_anns_transform.with_label_3d, True)

        data_info = create_dummy_data_info()
        info = load_anns_transform(data_info)

        self.assertIn('gt_bboxes_3d', info)
        # Reference value: sum over all entries of the dummy sample's
        # 3D box tensor.
        assert_allclose(info['gt_bboxes_3d'].tensor.sum(),
                        torch.tensor(7.2650))
        self.assertIn('gt_labels_3d', info)
        assert_allclose(info['gt_labels_3d'], torch.tensor([1]))

        repr_str = repr(load_anns_transform)
        self.assertIn('with_bbox_3d=True', repr_str)
        self.assertIn('with_label_3d=True', repr_str)
        self.assertIn('with_bbox_depth=False', repr_str)