# formating.py: data-formatting pipeline step for map tasks.
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet3d.core.points import BasePoints
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor

# Pipeline transform registered in the mmdet ``PIPELINES`` registry.
@PIPELINES.register_module()
class FormatBundleMap(object):
    """Format data for map tasks by wrapping fields in DataContainers.

    Every field below is optional; a field that is present in ``results``
    is replaced in place as follows.

    - points: must be a ``BasePoints`` instance; its ``.tensor`` is wrapped
      in a DataContainer
    - voxels, coors, voxel_centers, num_points: (1) to tensor,
      (2) to DataContainer (stack=False)
    - img (only when ``process_img`` is True): (1) transpose with axes
      ``(2, 0, 1)``, (2) to tensor, (3) to DataContainer (stack=True).
      A list of images is transposed one by one and stacked along a new
      leading axis before conversion.
    - semantic_mask: (1) to tensor, (2) to DataContainer (stack=True)
    - vectors: to DataContainer (stack=False, cpu_only=True)
    - polys: to DataContainer (stack=False, cpu_only=True)

    Args:
        process_img (bool): Whether to format ``results['img']``.
            Defaults to True.
        keys (Sequence[str]): Stored as ``self.keys``. Not read by
            ``__call__`` in this class.
        meta_keys (Sequence[str]): Stored as ``self.meta_keys``. Not read
            by ``__call__`` in this class.
    """

    # Keys converted with ``to_tensor`` and wrapped without stacking.
    _VOXEL_KEYS = ('voxels', 'coors', 'voxel_centers', 'num_points')

    def __init__(self, process_img=True,
                 keys=('img', 'semantic_mask', 'vectors'),
                 meta_keys=('intrinsics', 'extrinsics')):
        self.process_img = process_img
        # Copy into fresh lists so instances never share (or mutate) the
        # default argument objects or a caller-owned sequence.
        self.keys = list(keys)
        self.meta_keys = list(meta_keys)

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The same ``results`` dict, with the supported fields
                replaced by their DataContainer-wrapped versions.
        """
        # Format 3D data
        if 'points' in results:
            assert isinstance(results['points'], BasePoints), \
                "results['points'] must be a BasePoints instance"
            results['points'] = DC(results['points'].tensor)

        for key in self._VOXEL_KEYS:
            if key in results:
                results[key] = DC(to_tensor(results[key]), stack=False)

        if self.process_img and 'img' in results:
            img = results['img']
            if isinstance(img, list):
                # process multiple imgs in single frame
                img = np.stack([i.transpose(2, 0, 1) for i in img], axis=0)
            else:
                img = img.transpose(2, 0, 1)
            # transpose returns a strided view; make the memory contiguous
            # before handing it to to_tensor.
            img = np.ascontiguousarray(img)
            results['img'] = DC(to_tensor(img), stack=True)

        if 'semantic_mask' in results:
            results['semantic_mask'] = DC(
                to_tensor(results['semantic_mask']), stack=True)

        # vectors / polys may have different sizes per sample, so they are
        # kept as-is (no tensor conversion), unstacked and on CPU.
        for key in ('vectors', 'polys'):
            if key in results:
                results[key] = DC(results[key], stack=False, cpu_only=True)

        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        return (f'{self.__class__.__name__}('
                f'process_img={self.process_img}, '
                f'keys={self.keys}, '
                f'meta_keys={self.meta_keys})')