vectorize.py 5.24 KB
Newer Older
zhe chen's avatar
zhe chen committed
1
2
from typing import Dict, List, Tuple, Union

yeshenglong1's avatar
yeshenglong1 committed
3
4
5
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy.typing import NDArray
zhe chen's avatar
zhe chen committed
6
7
from shapely.geometry import LineString

yeshenglong1's avatar
yeshenglong1 committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

@PIPELINES.register_module(force=True)
class VectorizeMap(object):
    """Generate a vectorized map and store it under the `vectors` key.

    Concretely, shapely ``LineString`` objects are converted into arrays of
    sample points (ndarray). The args `sample_num`, `sample_dist` and
    `simplify` select the sampling method; at most one of them may be active.
    If none is active, the original vertices of each line are used.

    Args:
        roi_size (tuple or list): bev range.
        normalize (bool): whether to normalize the first two coordinates of
            every point by `roi_size` (see `normalize_line`).
        coords_dim (int): dimension of point coordinates.
        simplify (bool): whether to use shapely's simplify function. Cannot
            be combined with a positive `sample_num` or `sample_dist`.
        sample_num (int): number of points to interpolate from a polyline.
            Set to -1 to ignore.
        sample_dist (float): interpolate distance. Set to -1 to ignore.
    """

    def __init__(self,
                 roi_size: Union[Tuple, List],
                 normalize: bool,
                 coords_dim: int,
                 simplify: bool = False,
                 sample_num: int = -1,
                 sample_dist: float = -1,
                 ):
        self.coords_dim = coords_dim
        self.sample_num = sample_num
        self.sample_dist = sample_dist
        self.roi_size = np.array(roi_size)
        self.normalize = normalize
        self.simplify = simplify
        # Stays None when neither fixed-distance nor fixed-number sampling
        # is requested.
        self.sample_fn = None

        if sample_dist > 0:
            assert sample_num < 0 and not simplify, \
                '`sample_dist` cannot be combined with `sample_num` or `simplify`'
            self.sample_fn = self.interp_fixed_dist
        if sample_num > 0:
            assert sample_dist < 0 and not simplify, \
                '`sample_num` cannot be combined with `sample_dist` or `simplify`'
            self.sample_fn = self.interp_fixed_num

    def interp_fixed_num(self, line: LineString) -> NDArray:
        ''' Interpolate a line to a fixed number of points.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, D) with
                N = `sample_num` and D the coordinate dimension of the
                interpolated points.
        '''

        distances = np.linspace(0, line.length, self.sample_num)
        # Each interpolated Point carries exactly one coordinate tuple; taking
        # it directly keeps the result 2-D even when only one point is sampled
        # (a blanket squeeze would collapse that case to 1-D).
        sampled_points = np.array([line.interpolate(distance).coords[0]
                                   for distance in distances])

        return sampled_points

    def interp_fixed_dist(self, line: LineString) -> NDArray:
        ''' Interpolate a line at a fixed interval.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, D) with D the
                coordinate dimension of the interpolated points.
        '''

        distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
        # Always include both end points, so at least two points are sampled
        # even when sample_dist > line.length.
        distances = [0, ] + distances + [line.length, ]

        sampled_points = np.array([line.interpolate(distance).coords[0]
                                   for distance in distances])

        return sampled_points

    def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
        ''' Vectorize map elements. Iterate over the input dict and apply the
        specified sample function.

        Args:
            map_geoms (Dict): maps each label to a list of shapely
                geometries (LineString or Polygon).

        Returns:
            vectors (Dict): maps each label to a list of point arrays, one
                per LineString. Polygons are skipped.

        Raises:
            ValueError: if a geometry is neither a LineString nor a Polygon.
        '''

        vectors = {}
        for label, geom_list in map_geoms.items():
            vectors[label] = []
            for geom in geom_list:
                if geom.geom_type == 'LineString':
                    # keep only the first `coords_dim` coordinates per vertex
                    geom = LineString(np.array(geom.coords)[:, :self.coords_dim])
                    if self.simplify:
                        line = geom.simplify(0.2, preserve_topology=True)
                        line = np.array(line.coords)
                    elif self.sample_fn is not None:
                        line = self.sample_fn(geom)
                    else:
                        # no sampling requested: use the original vertices
                        line = np.array(geom.coords)

                    if self.normalize:
                        line = self.normalize_line(line)
                    vectors[label].append(line)

                elif geom.geom_type == 'Polygon':
                    # polygon objects will not be vectorized
                    continue

                else:
                    raise ValueError('map geoms must be either LineString or Polygon!')
        return vectors

    def normalize_line(self, line: NDArray) -> NDArray:
        ''' Shift and scale the first two coordinates of the points by the roi.

        The input array is modified in place and also returned.

        Args:
            line (array): points of a line, shape (N, D) with D >= 2.

        Returns:
            normalized (array): normalized points.
        '''

        origin = -np.array([self.roi_size[0] / 2, self.roi_size[1] / 2])

        # move the origin to the corner of the roi
        line[:, :2] = line[:, :2] - origin

        # divide by a slightly enlarged roi so that points lying on the far
        # roi border map to a value strictly below 1
        eps = 2
        line[:, :2] = line[:, :2] / (self.roi_size + eps)

        return line

    def __call__(self, input_dict):
        ''' Read `map_geoms` from the input dict and add the `vectors` key. '''
        map_geoms = input_dict['map_geoms']

        input_dict['vectors'] = self.get_vectorized_lines(map_geoms)
        return input_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(simplify={self.simplify}, '
        repr_str += f'sample_num={self.sample_num}, '
        repr_str += f'sample_dist={self.sample_dist}, '
        repr_str += f'roi_size={self.roi_size}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'coords_dim={self.coords_dim})'

        return repr_str