pse_postprocess.py 3.95 KB
Newer Older
WenmuZhou's avatar
WenmuZhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
WenmuZhou's avatar
WenmuZhou committed
14
15
16
17
"""
This code is adapted from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
WenmuZhou's avatar
WenmuZhou committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import cv2
import paddle
from paddle.nn import functional as F

from ppocr.postprocess.pse_postprocess.pse import pse


class PSEPostProcess(object):
    """
    The post process for PSE.

    Decodes the raw PSE kernel maps (``outs_dict['maps']``) into per-image
    text boxes or polygons expressed in source-image coordinates.
    """

    def __init__(self,
                 thresh=0.5,
                 box_thresh=0.85,
                 min_area=16,
                 box_type='box',
                 scale=4,
                 **kwargs):
        """
        Args:
            thresh: threshold applied to the raw (pre-sigmoid) map values to
                binarize every kernel map.
            box_thresh: minimum mean text score (sigmoid of channel 0) a
                region needs in order to be kept.
            min_area: minimum region size in map pixels; also forwarded to
                ``pse``.
            box_type: 'box' for minimum-area rectangles, 'poly' for contours.
            scale: the maps are upsampled by ``4 // scale`` before decoding.
            **kwargs: ignored (presumably to tolerate extra config keys).
        """
        assert box_type in ['box', 'poly'], 'Only box and poly is supported'
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.min_area = min_area
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        """
        Decode a batch of PSE predictions.

        Args:
            outs_dict: dict whose 'maps' entry holds the raw predictions,
                indexed as (batch, kernel, height, width); either a
                paddle.Tensor or anything ``paddle.to_tensor`` accepts.
            shape_list: per-image ``(src_h, src_w, ratio_h, ratio_w)``.

        Returns:
            A list with one ``{'points': boxes, 'scores': scores}`` dict per
            image.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(
            pred, scale_factor=4 // self.scale, mode='bilinear')

        # Text score = sigmoid of the first (full text) kernel.
        score = F.sigmoid(pred[:, 0, :, :]).numpy()

        # Binarize every kernel, then restrict each image's kernels to that
        # same image's text region (channel 0). Slicing with 0:1 keeps the
        # mask at (N, 1, H, W) so it broadcasts per image; an (N, H, W) mask
        # would only line up with (N, C, H, W) for a batch of one.
        kernels = (pred > self.thresh).numpy().astype(np.uint8)
        kernels = kernels * kernels[:, 0:1, :, :]

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            boxes, scores = self.boxes_from_bitmap(score[batch_index],
                                                   kernels[batch_index],
                                                   shape_list[batch_index])
            boxes_batch.append({'points': boxes, 'scores': scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, shape):
        """Run progressive scale expansion on one image and fit its boxes."""
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, shape)

    def generate_box(self, score, label, shape):
        """
        Fit a box or polygon to every labelled text region of one image.

        Args:
            score: (H, W) text score map.
            label: (H, W) integer label map; 0 is background and 1..max are
                region ids. Not modified.
            shape: ``(src_h, src_w, ratio_h, ratio_w)`` of the source image.

        Returns:
            (boxes, scores): one (K, 2) array of x/y points and one mean score
            per kept region. Coordinates are rounded and clipped to
            ``[0, src_w]`` / ``[0, src_h]``.
        """
        src_h, src_w, ratio_h, ratio_w = shape
        # __call__ upsamples the raw maps by 4 // scale; the 4 implies they
        # start at 1/4 of the network input, so one map pixel spans `stride`
        # input pixels. stride == 1.0 for scale=1, leaving coordinates as-is.
        stride = 4 / (4 // self.scale)
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            # (x, y) pixel coordinates of the region.
            points = np.argwhere(ind)[:, ::-1]
            if points.shape[0] < self.min_area:
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                continue

            if self.box_type == 'box':
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == 'poly':
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                # [-2] selects the contours from both the OpenCV 4 (2-tuple)
                # and OpenCV 3 (3-tuple) return forms.
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_SIMPLE)[-2]
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            bbox[:, 0] = np.clip(
                np.round(bbox[:, 0] * stride / ratio_w), 0, src_w)
            bbox[:, 1] = np.clip(
                np.round(bbox[:, 1] * stride / ratio_h), 0, src_h)
            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores