loading.py 2.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets.registry import PIPELINES


@PIPELINES.register_module
class LoadPointsFromFile(object):
    """Load a point cloud from a raw float32 binary file.

    The file is read with ``np.fromfile(..., dtype=np.float32)`` and
    reshaped to ``(-1, points_dim)``, i.e. one row per point.

    Args:
        points_dim (int): Number of float32 values stored per point.
            Default: 4.
        with_reflectivity (bool): Stored on the instance and shown in
            ``repr``. NOTE(review): ``__call__`` does not read this flag,
            so it currently has no effect on the loaded points.
            Default: True.
    """

    def __init__(self, points_dim=4, with_reflectivity=True):
        self.points_dim = points_dim
        self.with_reflectivity = with_reflectivity

    def __call__(self, results):
        """Read the point file referenced by ``results`` and store the points.

        Args:
            results (dict): Must contain ``'pts_prefix'`` (a directory or
                ``None``) and ``results['img_info']['filename']`` (the point
                file path, relative to ``pts_prefix`` when that is not None).

        Returns:
            dict: ``results``, updated in place with ``'points'``: a float32
            array of shape ``(num_points, points_dim)``.

        Raises:
            ValueError: If the number of floats in the file is not a
                multiple of ``points_dim`` (raised by ``reshape``).
        """
        prefix = results['pts_prefix']
        filename = results['img_info']['filename']
        if prefix is not None:
            filename = osp.join(prefix, filename)
        points = np.fromfile(filename, dtype=np.float32)
        results['points'] = points.reshape(-1, self.points_dim)
        return results

    def __repr__(self):
        # Report every constructor argument under its own name, in one
        # parenthesised argument list.
        return '{}(points_dim={}, with_reflectivity={})'.format(
            self.__class__.__name__, self.points_dim, self.with_reflectivity)


@PIPELINES.register_module
class LoadMultiViewImageFromFiles(object):
    """Load several image files and stack them into one multi-view array.

    Reads the list of file names at ``results['img_info']['filename']``
    (each joined with ``results['img_prefix']`` when that is not None),
    loads every file with ``mmcv.imread`` and stacks the images along a new
    trailing axis. The resolved file list is written to
    ``results['filename']``.

    Args:
        to_float32 (bool): Cast the stacked array to ``np.float32``.
            Default: False.
        color_type (str): Flag forwarded to ``mmcv.imread``.
            Default: 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type='unchanged'):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Load all views and fill in the default image meta keys.

        Args:
            results (dict): Must contain ``'img_prefix'`` (a directory or
                ``None``) and ``results['img_info']['filename']`` (an
                iterable of image paths).

        Returns:
            dict: ``results`` updated in place with ``'filename'``, ``'img'``,
            ``'img_shape'``, ``'ori_shape'``, ``'pad_shape'``,
            ``'scale_factor'`` (1.0) and ``'img_norm_cfg'``.

        Raises:
            ValueError: From ``np.stack`` if the file list is empty or the
                loaded images do not all have the same shape.
        """
        if results['img_prefix'] is not None:
            filename = [
                osp.join(results['img_prefix'], fname)
                for fname in results['img_info']['filename']
            ]
        else:
            filename = results['img_info']['filename']
        # Views are stacked on a new last axis, e.g. (H, W) images -> (H, W, V).
        img = np.stack(
            [mmcv.imread(name, self.color_type) for name in filename], axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)
        results['filename'] = filename
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        # NOTE(review): for 2-D inputs axis 2 is the view axis; for 3-channel
        # inputs the stacked array is 4-D and axis 2 is the colour axis —
        # confirm which one the normalisation config is meant to cover.
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        # Identity normalisation (zero mean, unit std) as a placeholder config.
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        return "{} (to_float32={}, color_type='{}')".format(
            self.__class__.__name__, self.to_float32, self.color_type)