# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Make this file's directory and its parent importable via sys.path.
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

import numpy as np
from .locality_aware_nms import nms_locality
import paddle
import cv2
import time


class SASTPostProcess(object):
    """
    The post process for SAST.

    Converts the per-image output maps found under the ``f_score``,
    ``f_border``, ``f_tvo`` and ``f_tco`` keys into text polygons expressed
    in source-image coordinates.
    """

    def __init__(self,
                 score_thresh=0.5,
                 nms_thresh=0.2,
                 sample_pts_num=2,
                 shrink_ratio_of_width=0.3,
                 expand_scale=1.0,
                 tcl_map_thresh=0.5,
                 **kwargs):
        """
        Args:
            score_thresh: stored on the instance but not read by any method
                of this class (NOTE(review): confirm nothing else uses it).
            nms_thresh: overlap threshold handed to the NMS routine.
            sample_pts_num: number of points sampled along each text
                instance; 0 means estimate it per instance.
            shrink_ratio_of_width: how far the two polygon ends are pushed
                outwards, relative to the local polygon height.
            expand_scale: border-offset expansion factor; 1.0 disables it.
            tcl_map_thresh: threshold applied to channel 0 of the TCL map.
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sample_pts_num = sample_pts_num
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # The C++ lanms implementation is faster, but is only supported on
        # Python 3.5; every other interpreter uses the Python fallback.
        self.is_python35 = sys.version_info[:2] == (3, 5)

    def point_pair2poly(self, point_pair_list):
        """
        Join a list of point pairs into one closed polygon.

        The first point of every pair is emitted in order, followed by the
        second points in reverse order, i.e. ``[a0, ..., an, bn, ..., b0]``,
        returned as an array of shape ``(2 * len(point_pair_list), 2)``.
        """
        point_num = len(point_pair_list) * 2
        point_list = [0] * point_num
        for idx, point_pair in enumerate(point_pair_list):
            point_list[idx] = point_pair[0]
            point_list[point_num - 1 - idx] = point_pair[1]
        return np.array(point_list).reshape(-1, 2)

    def shrink_quad_along_width(self,
                                quad,
                                begin_width_ratio=0.,
                                end_width_ratio=1.):
        """
        Return the slice of ``quad`` between two fractions of its width.

        ``quad`` holds four points ``p0, p1, p2, p3``; the new points are
        interpolated along the edges ``p0 -> p1`` and ``p3 -> p2``. Ratios
        outside ``[0, 1]`` extrapolate, which is how
        ``expand_poly_along_width`` grows a polygon.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        Push the two ends of ``poly`` outwards along the text direction.

        The first/last points are moved backwards and the two middle points
        (indices ``n // 2 - 1`` and ``n // 2``) forwards, each by about
        ``shrink_ratio_of_width`` times the polygon height measured near
        that end. ``poly`` is modified in place and also returned.
        """
        point_num = poly.shape[0]
        left_quad = np.array(
            [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                     (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
                                                        1.0)
        right_quad = np.array(
            [
                poly[point_num // 2 - 2], poly[point_num // 2 - 1],
                poly[point_num // 2], poly[point_num // 2 + 1]
            ],
            dtype=np.float32)
        right_ratio = 1.0 + \
                      shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                      (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
                                                         right_ratio)
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[point_num // 2 - 1] = right_quad_expand[1]
        poly[point_num // 2] = right_quad_expand[2]
        return poly

    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """
        Restore one candidate quad per pixel above ``tcl_map_thresh``.

        Returns:
            scores: ``(n, 1)`` channel-0 TCL scores of the selected pixels.
            quads: ``(n, 8)`` corner coordinates, i.e. each pixel's (x, y)
                repeated four times minus its 8 ``tvo_map`` values.
            xy_text: ``(n, 2)`` pixel coordinates in (x, y) order, sorted
                by y.
        """
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (row, col) -> (x, y)

        # Sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 1])]

        scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
        scores = scores[:, np.newaxis]

        # Restore
        point_num = int(tvo_map.shape[-1] / 2)
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text

    def quad_area(self, quad):
        """
        Return the sign-flipped shoelace area of a 4-point quad.

        Computes ``sum((x[i+1] - x[i]) * (y[i+1] + y[i])) / 2``, which is
        the negative of the standard signed shoelace area; for example
        ``[[0, 0], [4, 0], [4, 2], [0, 2]]`` gives ``-8``. ``detect_sast``
        negates the result.
        """
        edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
                (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
                (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
                (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
        return np.sum(edge) / 2.

    def nms(self, dets):
        """
        Merge overlapping detections of shape ``(n, 9)`` (8 coords + score).

        Uses the C++ ``lanms`` module on Python 3.5 and ``nms_locality``
        otherwise, both with ``self.nms_thresh``.
        """
        if self.is_python35:
            import lanms
            dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
        else:
            dets = nms_locality(dets, self.nms_thresh)
        return dets

    def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
        """
        Assign every pixel above ``tcl_map_thresh`` to its nearest quad.

        Each pixel's predicted text centre is its (x, y) position minus its
        ``tco_map`` offset; it is labelled with the 1-based index of the quad
        whose corner mean is closest. Label 0 is background.

        Returns:
            ``(quads.shape[0] + 1, label_map)`` where ``label_map`` is an
            int32 array of shape ``tcl_map.shape[:2]``.
        """
        instance_count = quads.shape[0] + 1  # contain background
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map

        # predict text center
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
        pred_tc = xy_text - tco

        # get gt text center
        gt_tc = np.mean(quads, axis=1)  # (m, 2)

        # Broadcasting yields the same (n, m, 2) differences as tiling both
        # operands, without materialising the two tiled copies.
        diff = pred_tc[:, np.newaxis, :] - gt_tc[np.newaxis, :, :]
        dist_mat = np.linalg.norm(diff, axis=2)  # (n, m)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map

    def estimate_sample_pts_num(self, quad, xy_text):
        """
        Estimate how many points to sample along one text instance.

        ``eh`` is the mean length of quad edges p0-p3 and p1-p2 and ``ew``
        the mean length of edges p0-p1 and p2-p3. The arc length of the
        (already sorted) pixel path ``xy_text`` is approximated by a
        polyline through ``max(2, int(ew))`` evenly spaced pixels, and the
        result is ``max(2, int(arc_len / eh))``. A degenerate quad with
        ``eh == 0`` returns the minimum of 2 instead of dividing by zero.
        """
        eh = (np.linalg.norm(quad[0] - quad[3]) +
              np.linalg.norm(quad[1] - quad[2])) / 2.0
        ew = (np.linalg.norm(quad[0] - quad[1]) +
              np.linalg.norm(quad[2] - quad[3])) / 2.0
        if eh <= 0:
            # int(inf) / int(nan) would raise; fall back to the minimum.
            return 2

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[np.linspace(
            0,
            xy_text.shape[0] - 1,
            dense_sample_pts_num,
            endpoint=True,
            dtype=np.float32).astype(np.int32)]

        dense_xy_center_line_diff = dense_xy_center_line[
            1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(
            np.linalg.norm(
                dense_xy_center_line_diff, axis=1))

        sample_pts_num = max(2, int(estimate_arc_len / eh))
        return sample_pts_num

    def detect_sast(self,
                    tcl_map,
                    tvo_map,
                    tbo_map,
                    tco_map,
                    ratio_w,
                    ratio_h,
                    src_w,
                    src_h,
                    shrink_ratio_of_width=0.3,
                    tcl_map_thresh=0.5,
                    offset_expand=1.0,
                    out_strid=4.0):
        """
        Restore the text polygons of one image from its SAST output maps.

        Candidate quads from ``tvo_map`` are merged by NMS. Pixels above
        ``tcl_map_thresh`` are then clustered onto the surviving quads, and
        each kept instance becomes a polygon built from the ``tbo_map``
        border offsets of pixels sampled along it. Feature-map coordinates
        are multiplied by ``out_strid``, divided by ``(ratio_w, ratio_h)``
        and clipped to ``[0, src_w] x [0, src_h]``.

        Returns:
            list of polygons, each of shape ``(2 * k, 2)`` in (x, y) order,
            where ``k`` is the number of points sampled for that instance.
        """
        # restore quad
        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
                                                   tvo_map)
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        if dets.shape[0] == 0:
            # Nothing passed the threshold, so there is nothing to merge.
            return []
        dets = self.nms(dets)
        if dets.shape[0] == 0:
            return []
        quads = dets[:, :-1].reshape(-1, 4, 2)

        # quad_area returns the sign-flipped shoelace area; negate it back.
        quad_areas = [-self.quad_area(quad) for quad in quads]

        # instance segmentation
        instance_count, instance_label_map = self.cluster_by_quads_tco(
            tcl_map, tcl_map_thresh, quads, tco_map)

        # Loop-invariant divisor mapping (x, y) back to the source image.
        ratio_wh = np.array([ratio_w, ratio_h]).reshape(-1, 2)

        # restore single poly with tcl instance.
        poly_list = []
        for instance_idx in range(1, instance_count):
            xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
            quad = quads[instance_idx - 1]
            q_area = quad_areas[instance_idx - 1]
            if q_area < 5:
                continue

            # filter quads whose shorter side is tiny
            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            if min(len1, len2) < 3:
                continue

            # filter small CC
            if xy_text.shape[0] <= 0:
                continue

            # filter low confidence instance: TCL score mass per unit area
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / q_area < 0.1:
                continue

            # sort pixels by their projection on the quad's left->right axis
            left_center_pt = np.array(
                [[(quad[0, 0] + quad[-1, 0]) / 2.0,
                  (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
            right_center_pt = np.array(
                [[(quad[1, 0] + quad[2, 0]) / 2.0,
                  (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
            proj_unit_vec = (right_center_pt - left_center_pt) / \
                            (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
            xy_text = xy_text[np.argsort(proj_value)]

            # Sample pts in tcl map
            if self.sample_pts_num == 0:
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[np.linspace(
                0,
                xy_text.shape[0] - 1,
                sample_pts_num,
                endpoint=True,
                dtype=np.float32).astype(np.int32)]

            point_pair_list = []
            for x, y in xy_center_line:
                # get corresponding offset: two points in (y, x) order
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    # A zero-length offset has no direction; dividing by 1
                    # there gives a zero delta instead of NaN coordinates.
                    safe_length = np.where(offset_length > 0, offset_length,
                                           1.0)
                    offset_delta = offset / safe_length * expand_length
                    offset = offset + offset_delta
                # original point
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * out_strid / ratio_wh
                point_pair_list.append(point_pair)

            # ndarray: (x, 2), expand poly along width
            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(detected_poly,
                                                         shrink_ratio_of_width)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    @staticmethod
    def _stack_polys(poly_list):
        """
        Pack the polygons of one image into a single numpy array.

        Polygons with equal point counts give a regular ``(k, p, 2)`` array
        (or an empty array for no polygons), exactly as ``np.array`` does.
        When counts differ, which happens when ``sample_pts_num == 0``
        because each instance then estimates its own count, a 1-D object
        array of the individual polygons is returned. Recent NumPy refuses
        to build a regular array from such ragged input.
        """
        if len({len(poly) for poly in poly_list}) <= 1:
            return np.array(poly_list)
        packed = np.empty(len(poly_list), dtype=object)
        for idx, poly in enumerate(poly_list):
            packed[idx] = poly
        return packed

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch of SAST outputs.

        Args:
            outs_dict: mapping with 'f_score', 'f_border', 'f_tvo' and
                'f_tco' batched outputs. They are converted to numpy when
                'f_score' is a ``paddle.Tensor``.
            shape_list: one ``(src_h, src_w, ratio_h, ratio_w)`` per image.

        Returns:
            list with one ``{'points': polygons}`` dict per image.
        """
        score_list = outs_dict['f_score']
        border_list = outs_dict['f_border']
        tvo_list = outs_dict['f_tvo']
        tco_list = outs_dict['f_tco']
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            border_list = border_list.numpy()
            tvo_list = tvo_list.numpy()
            tco_list = tco_list.numpy()

        poly_lists = []
        for ino, (src_h, src_w, ratio_h, ratio_w) in enumerate(shape_list):
            # Per-image maps are moved to channel-last: (C, H, W) -> (H, W, C).
            p_score = score_list[ino].transpose((1, 2, 0))
            p_border = border_list[ino].transpose((1, 2, 0))
            p_tvo = tvo_list[ino].transpose((1, 2, 0))
            p_tco = tco_list[ino].transpose((1, 2, 0))

            poly_list = self.detect_sast(
                p_score,
                p_tvo,
                p_border,
                p_tco,
                ratio_w,
                ratio_h,
                src_w,
                src_h,
                shrink_ratio_of_width=self.shrink_ratio_of_width,
                tcl_map_thresh=self.tcl_map_thresh,
                offset_expand=self.expand_scale)
            poly_lists.append({'points': self._stack_polys(poly_list)})

        return poly_lists