"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "d9351d9a9823176f76190fb3739e5b3b02de4c52"
Commit 1f76f449 authored by Jethong's avatar Jethong
Browse files

Add PGNet

parent 1a087990
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import *
import paddle
import cv2
import time
class PGPostProcess(object):
"""
The post process for SAST.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.result_path = ""
self.valid_set = 'totaltext'
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score, p_border, p_direction, p_char = outs_dict[:4]
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
src_h, src_w, ratio_h, ratio_w = shape_list[0]
if self.valid_set != 'totaltext':
is_curved = False
else:
is_curved = True
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = np.expand_dims(p_char, axis=0)
p_char = paddle.to_tensor(p_char)
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
t = len(featyre_seq[0])
featyre_seq = paddle.to_tensor(featyre_seq)
l = np.array([[t]]).astype(np.int64)
length = paddle.to_tensor(l)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred1 = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for x in seq_pred1[:seq_len]:
temp_t.append(x)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
print('the length of tcl point is less than 2, repeat')
yx_center_line.append(yx_center_line[-1])
# expand corresponding offset for total-text.
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# for visualization
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
# ndarry: (x, 2)
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
print('expand along width. {}'.format(detected_poly.shape))
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
# visualization
# if self.save_visualization:
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
# save detected boxes
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
# if not os.path.exists(txt_dir):
# os.makedirs(txt_dir)
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
# with open(res_file, 'w') as f:
# for i_box, box in enumerate(poly_list):
# seq_str = keep_str_list[i_box]
# box = np.round(box).astype('int32')
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
return data
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(__file__)
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.join(__dir__, '..'))
...@@ -49,12 +50,12 @@ class SASTPostProcess(object): ...@@ -49,12 +50,12 @@ class SASTPostProcess(object):
self.shrink_ratio_of_width = shrink_ratio_of_width self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5 # c++ la-nms is faster, but only support python 3.5
self.is_python35 = False self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5: if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True self.is_python35 = True
def point_pair2poly(self, point_pair_list): def point_pair2poly(self, point_pair_list):
""" """
Transfer vertical point_pairs into poly point in clockwise. Transfer vertical point_pairs into poly point in clockwise.
...@@ -66,31 +67,42 @@ class SASTPostProcess(object): ...@@ -66,31 +67,42 @@ class SASTPostProcess(object):
point_list[idx] = point_pair[0] point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1] point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2) return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
""" """
Generate shrink_quad_along_width. Generate shrink_quad_along_width.
""" """
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
""" """
expand poly along width. expand poly along width.
""" """
point_num = poly.shape[0] point_num = poly.shape[0]
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1], 1.0)
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32) right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \ right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
right_ratio)
poly[0] = left_quad_expand[0] poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1] poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1] poly[point_num // 2 - 1] = right_quad_expand[1]
...@@ -100,7 +112,7 @@ class SASTPostProcess(object): ...@@ -100,7 +112,7 @@ class SASTPostProcess(object):
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad.""" """Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
xy_text = xy_text[:, ::-1] # (n, 2) xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis # Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])] xy_text = xy_text[np.argsort(xy_text[:, 1])]
...@@ -112,7 +124,7 @@ class SASTPostProcess(object): ...@@ -112,7 +124,7 @@ class SASTPostProcess(object):
point_num = int(tvo_map.shape[-1] / 2) point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4 assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map quads = xy_text_tile - tvo_map
return scores, quads, xy_text return scores, quads, xy_text
...@@ -121,14 +133,12 @@ class SASTPostProcess(object): ...@@ -121,14 +133,12 @@ class SASTPostProcess(object):
""" """
compute area of a quad. compute area of a quad.
""" """
edge = [ edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
]
return np.sum(edge) / 2. return np.sum(edge) / 2.
def nms(self, dets): def nms(self, dets):
if self.is_python35: if self.is_python35:
import lanms import lanms
...@@ -141,7 +151,7 @@ class SASTPostProcess(object): ...@@ -141,7 +151,7 @@ class SASTPostProcess(object):
""" """
Cluster pixels in tcl_map based on quads. Cluster pixels in tcl_map based on quads.
""" """
instance_count = quads.shape[0] + 1 # contain background instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1: if instance_count == 1:
return instance_count, instance_label_map return instance_count, instance_label_map
...@@ -149,18 +159,19 @@ class SASTPostProcess(object): ...@@ -149,18 +159,19 @@ class SASTPostProcess(object):
# predict text center # predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0] n = xy_text.shape[0]
xy_text = xy_text[:, ::-1] # (n, 2) xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco pred_tc = xy_text - tco
# get gt text center # get gt text center
m = quads.shape[0] m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2) gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) (1, m, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map return instance_count, instance_label_map
...@@ -169,26 +180,47 @@ class SASTPostProcess(object): ...@@ -169,26 +180,47 @@ class SASTPostProcess(object):
""" """
Estimate sample points number. Estimate sample points number.
""" """
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0 eh = (np.linalg.norm(quad[0] - quad[3]) +
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew)) dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num, dense_xy_center_line = xy_text[np.linspace(
endpoint=True, dtype=np.float32).astype(np.int32)] 0,
xy_text.shape[0] - 1,
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] dense_sample_pts_num,
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) endpoint=True,
dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[
1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(
np.linalg.norm(
dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh)) sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num return sample_pts_num
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h, def detect_sast(self,
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0): tcl_map,
tvo_map,
tbo_map,
tco_map,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=0.3,
tcl_map_thresh=0.5,
offset_expand=1.0,
out_strid=4.0):
""" """
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
""" """
# restore quad # restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False) dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets) dets = self.nms(dets)
if dets.shape[0] == 0: if dets.shape[0] == 0:
...@@ -202,7 +234,8 @@ class SASTPostProcess(object): ...@@ -202,7 +234,8 @@ class SASTPostProcess(object):
# instance segmentation # instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map) instance_count, instance_label_map = self.cluster_by_quads_tco(
tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance. # restore single poly with tcl instance.
poly_list = [] poly_list = []
...@@ -212,10 +245,10 @@ class SASTPostProcess(object): ...@@ -212,10 +245,10 @@ class SASTPostProcess(object):
q_area = quad_areas[instance_idx - 1] q_area = quad_areas[instance_idx - 1]
if q_area < 5: if q_area < 5:
continue continue
# #
len1 = float(np.linalg.norm(quad[0] -quad[1])) len1 = float(np.linalg.norm(quad[0] - quad[1]))
len2 = float(np.linalg.norm(quad[1] -quad[2])) len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2) min_len = min(len1, len2)
if min_len < 3: if min_len < 3:
continue continue
...@@ -225,16 +258,18 @@ class SASTPostProcess(object): ...@@ -225,16 +258,18 @@ class SASTPostProcess(object):
continue continue
# filter low confidence instance # filter low confidence instance
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue continue
# sort xy_text # sort xy_text
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0, left_center_pt = np.array(
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) [[(quad[0, 0] + quad[-1, 0]) / 2.0,
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) right_center_pt = np.array(
[[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \ proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1) proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
...@@ -245,33 +280,45 @@ class SASTPostProcess(object): ...@@ -245,33 +280,45 @@ class SASTPostProcess(object):
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else: else:
sample_pts_num = self.sample_pts_num sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num, xy_center_line = xy_text[np.linspace(
endpoint=True, dtype=np.float32).astype(np.int32)] 0,
xy_text.shape[0] - 1,
sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
point_pair_list = [] point_pair_list = []
for x, y in xy_center_line: for x, y in xy_center_line:
# get corresponding offset # get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2) offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0: if offset_expand != 1.0:
offset_length = np.linalg.norm(offset, axis=1, keepdims=True) offset_length = np.linalg.norm(
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0) offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal offset = offset + offset_detal
# original point # original point
ori_yx = np.array([y, x], dtype=np.float32) ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2) point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair) point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width # ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list) detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width) detected_poly = self.expand_poly_along_width(detected_poly,
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) shrink_ratio_of_width)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly) poly_list.append(detected_poly)
return poly_list return poly_list
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score'] score_list = outs_dict['f_score']
border_list = outs_dict['f_border'] border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo'] tvo_list = outs_dict['f_tvo']
...@@ -281,20 +328,28 @@ class SASTPostProcess(object): ...@@ -281,20 +328,28 @@ class SASTPostProcess(object):
border_list = border_list.numpy() border_list = border_list.numpy()
tvo_list = tvo_list.numpy() tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy() tco_list = tco_list.numpy()
img_num = len(shape_list) img_num = len(shape_list)
poly_lists = [] poly_lists = []
for ino in range(img_num): for ino in range(img_num):
p_score = score_list[ino].transpose((1,2,0)) p_score = score_list[ino].transpose((1, 2, 0))
p_border = border_list[ino].transpose((1,2,0)) p_border = border_list[ino].transpose((1, 2, 0))
p_tvo = tvo_list[ino].transpose((1,2,0)) p_tvo = tvo_list[ino].transpose((1, 2, 0))
p_tco = tco_list[ino].transpose((1,2,0)) p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino] src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, poly_list = self.detect_sast(
shrink_ratio_of_width=self.shrink_ratio_of_width, p_score,
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale) p_tvo,
p_border,
p_tco,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh,
offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)}) poly_lists.append({'points': np.array(poly_list)})
return poly_lists return poly_lists
This diff is collapsed.
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the calculation of 'AREA' in this script is handled by:
1) First generating a binary mask with the polygon area filled up with 1's
2) Summing up all the 1's
"""
def area(x, y):
polygon = Polygon(np.stack([x, y], axis=1))
return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
"""
This helper determine if both polygons are intersecting with each others with an approximation method.
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
"""
det_ymax = np.max(det_y)
det_xmax = np.max(det_x)
det_ymin = np.min(det_y)
det_xmin = np.min(det_x)
gt_ymax = np.max(gt_y)
gt_xmax = np.max(gt_x)
gt_ymin = np.min(gt_y)
gt_xmin = np.min(gt_x)
all_min_ymax = np.minimum(det_ymax, gt_ymax)
all_max_ymin = np.maximum(det_ymin, gt_ymin)
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
all_min_xmax = np.minimum(det_xmax, gt_xmax)
all_max_xmin = np.maximum(det_xmin, gt_xmin)
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
"""
This helper determine the fraction of intersection area over detection area
"""
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area(det_x, det_y) + 1.0)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser): ...@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None): def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv) args = super(ArgsParser, self).parse_args(argv)
args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \ assert args.config is not None, \
"Please specify --config=configure_file_path." "Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt) args.opt = self._parse_opt(args.opt)
...@@ -374,7 +375,8 @@ def preprocess(is_train=False): ...@@ -374,7 +375,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PG'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment