Commit 1f76f449 authored by Jethong

Add PGNet

parent 1a087990
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import *
import paddle
import cv2
import time
class PGPostProcess(object):
"""
    The post process for PGNet.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.result_path = ""
self.valid_set = 'totaltext'
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
        # The C++ la-nms is faster, but it only supports Python 3.5.
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score, p_border, p_direction, p_char = outs_dict[:4]
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
src_h, src_w, ratio_h, ratio_w = shape_list[0]
        is_curved = self.valid_set == 'totaltext'
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = np.expand_dims(p_char, axis=0)
p_char = paddle.to_tensor(p_char)
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
            t = len(feature_seq[0])
            feature_seq = paddle.to_tensor(feature_seq)
l = np.array([[t]]).astype(np.int64)
length = paddle.to_tensor(l)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
                input=feature_seq, blank=36, input_length=length)
seq_pred1 = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
            char_seq_idx_set.append(seq_pred1[:seq_len])
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
                print('The TCL instance has fewer than 2 points; repeating the last point.')
yx_center_line.append(yx_center_line[-1])
# expand corresponding offset for total-text.
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
                    offset_delta = offset / offset_length * expand_length
                    offset = offset + offset_delta
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# for visualization
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
            # ndarray: (n, 2)
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
print('expand along width. {}'.format(detected_poly.shape))
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
# visualization
# if self.save_visualization:
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
# save detected boxes
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
# if not os.path.exists(txt_dir):
# os.makedirs(txt_dir)
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
# with open(res_file, 'w') as f:
# for i_box, box in enumerate(poly_list):
# seq_str = keep_str_list[i_box]
# box = np.round(box).astype('int32')
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
return data
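# Usage sketch (tensor shapes assumed for illustration, not taken from the
# original commit): the PGNet heads yield, per image, a TCL score map, a
# 4-channel border-offset map, a direction map and a 37-channel character map
# (the 36 symbols of Lexicon_Table plus the CTC blank), all downsampled 4x
# from the input:
#     post = PGPostProcess(score_thresh=0.5)
#     outs = [f_score, f_border, f_direction, f_char]  # NCHW paddle.Tensor each
#     data = post(outs, shape_list)  # shape_list[0] = [src_h, src_w, ratio_h, ratio_w]
#     data['points']  # list of (n, 2) polygons in source-image coordinates
#     data['strs']    # list of recognized strings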
@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
@@ -49,12 +50,12 @@ class SASTPostProcess(object):
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
        # The C++ la-nms is faster, but it only supports Python 3.5.
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def point_pair2poly(self, point_pair_list):
"""
        Transfer vertical point pairs into polygon points in clockwise order.
@@ -66,31 +67,42 @@ class SASTPostProcess(object):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
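    # A worked sketch (illustrative numbers, not part of the original code):
    # for quad = [[0, 0], [10, 0], [10, 2], [0, 2]],
    # shrink_quad_along_width(quad, 0.2, 0.8) keeps the middle 60% of the
    # width and returns [[2, 0], [8, 0], [8, 2], [2, 2]]: the top edge
    # (p0 -> p1) and the bottom edge (p3 -> p2) are each interpolated at the
    # begin/end ratios, so the corner order is preserved.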
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
@@ -100,7 +112,7 @@ class SASTPostProcess(object):
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
@@ -112,7 +124,7 @@ class SASTPostProcess(object):
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
@@ -121,14 +133,12 @@ class SASTPostProcess(object):
"""
compute area of a quad.
"""
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
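    # The edge sum above is the shoelace (trapezoid) formula: the signed area
    # is sum((x_{i+1} - x_i) * (y_{i+1} + y_i)) / 2 over the four edges, with
    # the sign determined by the winding order. Illustrative check (values
    # assumed): quad = [[0, 0], [0, 1], [1, 1], [1, 0]] gives +1.0, and
    # reversing the corner order flips the sign.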
def nms(self, dets):
if self.is_python35:
import lanms
@@ -141,7 +151,7 @@ class SASTPostProcess(object):
"""
Cluster pixels in tcl_map based on quads.
"""
instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
@@ -149,18 +159,19 @@ class SASTPostProcess(object):
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
# get gt text center
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
(1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
@@ -169,26 +180,47 @@ class SASTPostProcess(object):
"""
        Estimate the number of sample points.
"""
eh = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
dense_sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[
1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(
np.linalg.norm(
dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
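    # Rough intuition with assumed numbers: for a straight quad of average
    # height eh = 10 and width ew = 100, the centerline is densely sampled at
    # int(ew) = 100 points, the estimated arc length is ~100, and
    # sample_pts_num = max(2, int(100 / 10)) = 10, i.e. roughly one sample per
    # text-height along the (possibly curved) center line.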
def detect_sast(self,
tcl_map,
tvo_map,
tbo_map,
tco_map,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=0.3,
tcl_map_thresh=0.5,
offset_expand=1.0,
                    out_stride=4.0):
"""
        First resize tcl_map, tvo_map and tbo_map to the input size, then restore the polygons.
"""
# restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
@@ -202,7 +234,8 @@ class SASTPostProcess(object):
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(
tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
@@ -212,10 +245,10 @@ class SASTPostProcess(object):
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
len1 = float(np.linalg.norm(quad[0] - quad[1]))
len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
@@ -225,16 +258,18 @@ class SASTPostProcess(object):
continue
# filter low confidence instance
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
left_center_pt = np.array(
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array(
[[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
@@ -245,33 +280,45 @@ class SASTPostProcess(object):
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(
0,
xy_text.shape[0] - 1,
sample_pts_num,
endpoint=True,
dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
                    offset_delta = offset / offset_length * expand_length
                offset = offset + offset_delta
# original point
ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * out_stride / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
            # ndarray: (n, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly,
shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
@@ -281,20 +328,28 @@ class SASTPostProcess(object):
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
p_score = score_list[ino].transpose((1, 2, 0))
p_border = border_list[ino].transpose((1, 2, 0))
p_tvo = tvo_list[ino].transpose((1, 2, 0))
p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(
p_score,
p_tvo,
p_border,
p_tco,
ratio_w,
ratio_h,
src_w,
src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh,
offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists
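# Usage sketch (tensor shapes assumed for illustration): outs_dict carries the
# four SAST heads as NCHW paddle.Tensor maps,
#     outs_dict = {'f_score': ..., 'f_border': ..., 'f_tvo': ..., 'f_tco': ...}
# and shape_list[i] = [src_h, src_w, ratio_h, ratio_w] for image i, so that
#     SASTPostProcess()(outs_dict, shape_list)
# returns one {'points': np.ndarray} dict per input image.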
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end-of-line token ('\n').
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1
def input_reading_mod(pred_dict, input):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict, gt_id):
"""This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
            print("liushanshan gt[1] = {}".format(gt[1]))
            print("liushanshan gt[2] = {}".format(gt[2]))
            print("liushanshan gt[3] = {}".format(gt[3]))
            print("liushanshan gt[4] = {}".format(gt[4]))
            print("liushanshan gt[5] = {}".format(gt[5]))
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
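    # Worked example with assumed coordinates (not from the original file):
    # for a 2x2 ground truth at the origin (gt_x = [0, 2, 2, 0],
    # gt_y = [0, 0, 2, 2]) and a 2x2 detection shifted by 1 along x
    # (det_x = [1, 3, 3, 1], det_y = [0, 0, 2, 2]), the intersection area is
    # 2.0, so sigma = 2 / 4 = 0.5 (a recall-like overlap with the ground
    # truth) and tau = 2 / 4 = 0.5 (a precision-like overlap with the
    # detection).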
##############################Initialization###################################
global_tp = 0
global_fp = 0
global_fn = 0
global_sigma = []
global_tau = []
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
global_pred_str = []
global_gt_str = []
###############################################################################
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id)
detections = input_reading_mod(pred_dict, input_id)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma.append(local_sigma_table)
global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str)
            print("liushanshan global_pred_str = {}".format(global_pred_str))
            print("liushanshan global_gt_str = {}".format(global_gt_str))
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
                print("liushanshan one to one det_id = {}".format(matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
                print("liushanshan one to one gt_str_cur = {}".format(gt_str_cur))
                print("liushanshan one to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
                        print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
                        print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
                        print("liushanshan many to one det_id = {}".format(det_id))
                        print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
                            if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
                        print("liushanshan many to one det_id = {}".format(det_id))
                        print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                            if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
                                    print('no match')
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
single_data = {}
for idx in range(len(global_sigma)):
# print(allInputs[idx])
local_sigma_table = global_sigma[idx]
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
# fid = open(fid_path, 'a+')
try:
local_precision = local_accumulative_precision / num_det
except ZeroDivisionError:
local_precision = 0
try:
local_recall = local_accumulative_recall / num_gt
except ZeroDivisionError:
local_recall = 0
try:
local_f_score = 2 * local_precision * local_recall / (
local_precision + local_recall)
except ZeroDivisionError:
local_f_score = 0
# temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % (
# allInputs[idx], local_recall, local_precision, local_f_score))
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
single_data["recall"] = local_recall
single_data['precision'] = local_precision
single_data['f_score'] = local_f_score
return single_data
def combine_results(all_data):
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
global_sigma = []
global_tau = []
global_pred_str = []
global_gt_str = []
for data in all_data:
global_sigma.append(data['sigma'][0])
global_tau.append(data['global_tau'][0])
global_pred_str.append(data['global_pred_str'][0])
global_gt_str.append(data['global_gt_str'][0])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
                print("liushanshan one to one det_id = {}".format(matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
                print("liushanshan one to one gt_str_cur = {}".format(gt_str_cur))
                print("liushanshan one to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
                        print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
                        print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
                        print("liushanshan many to one det_id = {}".format(det_id))
                        print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
                        print("liushanshan many to one det_id = {}".format(det_id))
                        print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                            if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
                                    print('no match')
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
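# Worked illustration with assumed counts (not from the original file): with
# total_num_gt = 100, total_num_det = 90, global_accumulative_recall = 85 and
# hit_str_count = 80, detection recall = 85 / 100 = 0.85, while end to end
# recall_e2e = 80 / 100 = 0.80, precision_e2e = 80 / 90 ~= 0.889,
# f_score_e2e ~= 0.842 and seqerr = 1 - 80 / 85 ~= 0.059.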
# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
# {'points': np.array(a), 'text': 'ccf'}]
# result = []
# for i in range(2):
# result.append(get_socre(gt_dict, pred_dict))
# print(111)
# a = combine_results(result)
# print(a)
import numpy as np
from shapely.geometry import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the 'AREA' calculations in this script are handled by Shapely polygon
operations (polygon area, intersection and union).
"""
def area(x, y):
polygon = Polygon(np.stack([x, y], axis=1))
return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
"""
    This helper determines whether the two polygons intersect, using an approximation:
    the intersection area is represented by the overlap of the minimum bounding rectangles [xmin, ymin, xmax, ymax].
"""
det_ymax = np.max(det_y)
det_xmax = np.max(det_x)
det_ymin = np.min(det_y)
det_xmin = np.min(det_x)
gt_ymax = np.max(gt_y)
gt_xmax = np.max(gt_x)
gt_ymin = np.min(gt_y)
gt_xmin = np.min(gt_x)
all_min_ymax = np.minimum(det_ymax, gt_ymax)
all_max_ymin = np.maximum(det_ymin, gt_ymin)
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
all_min_xmax = np.minimum(det_xmax, gt_xmax)
all_max_xmin = np.maximum(det_xmin, gt_xmin)
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
"""
    This helper determines the fraction of the intersection area over the detection area.
"""
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area(det_x, det_y) + 1.0)
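# Minimal self-check sketch (example values assumed, not part of the original
# file): two 2x2 squares overlapping in a 1x2 strip.
if __name__ == '__main__':
    gt_x, gt_y = [0, 2, 2, 0], [0, 0, 2, 2]
    det_x, det_y = [1, 3, 3, 1], [0, 0, 2, 2]
    print(area(gt_x, gt_y))  # 4.0
    print(area_of_intersection(det_x, det_y, gt_x, gt_y))  # 2.0
    print(iou(det_x, det_y, gt_x, gt_y))  # 2 / (6 + 1) ~= 0.286 (note the +1.0 smoothing)
    print(iod(det_x, det_y, gt_x, gt_y))  # 2 / (4 + 1) = 0.4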
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end-of-line token ('\n').
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
global_tp = 0
global_fp = 0
global_fn = 0
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1
global_pred_str = []
global_gt_str = []
global_sigma = []
global_tau = []
def input_reading_mod(pred_dict, input):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict, gt_id):
"""This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
            print("liushanshan gt[1] = {}".format(gt[1]))
            print("liushanshan gt[2] = {}".format(gt[2]))
            print("liushanshan gt[3] = {}".format(gt[3]))
            print("liushanshan gt[4] = {}".format(gt[4]))
            print("liushanshan gt[5] = {}".format(gt[5]))
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
###############################################################################
single_data = {}
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id)
detections = input_reading_mod(pred_dict, input_id)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma.append(local_sigma_table)
global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str)
            print("liushanshan global_pred_str = {}".format(global_pred_str))
            print("liushanshan global_gt_str = {}".format(global_gt_str))
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def combine_results(all_data):
global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
for data in all_data:
global_sigma.append(data['sigma'])
global_tau.append(data['global_tau'])
global_pred_str.append(data['global_pred_str'])
global_gt_str.append(data['global_gt_str'])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
                print("liushanshan one to one det_id = {}".format(matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
print("liushanshan one to many gt_id = {}".format(gt_id))
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
print("liushanshan many to one det_id = {}".format(det_id))
print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# dict.has_key() is Python 2 only; use the "in" operator instead
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
print("liushanshan many to one det_id = {}".format(det_id))
print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
print('no match')
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = np.array(global_tau[idx])
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
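# Note (added for clarity): the detection metrics follow the Deteval protocol,
# recall = global_accumulative_recall / total_num_gt and
# precision = global_accumulative_precision / total_num_det; the e2e variants
# substitute hit_str_count (correctly recognized transcriptions) for the
# geometric match counts.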
# quick self-test with one ground-truth polygon and two dummy predictions
a = [
1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
1573, 685, 1542, 694
]
gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
pred_dict = [{
'points': np.array(a),
'text': 'ccc'
}, {
'points': np.array(a),
'text': 'ccf'
}]
result = []
# score the predictions against the ground truth (pred_dict was previously
# unused because gt_dict had been passed twice)
result.append(get_socre(gt_dict, pred_dict))
metrics = combine_results(result)
print(metrics)
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import time
import math
import numpy as np
from itertools import groupby
from ppocr.utils.e2e_utils.ski_thin import thin
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
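# Illustrative check (added, not in the original source): softmax of uniform
# logits is uniform, e.g. softmax(np.zeros((1, 4))) -> [[0.25, 0.25, 0.25, 0.25]].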
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
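# Worked example (added for clarity): for labels [1, 1, 2, 2, 2, 0] and
# remove_blank=None, groupby collapses the runs to chars [1, 2, 0], kept at
# the middle positions [1, 3, 5]; with remove_blank=0 the trailing blank run
# is dropped, giving ([1, 2], [1, 3]).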
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
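# Toy example (added; values are hypothetical): with a 4-class vocabulary
# where index 3 is the blank, a probs_seq whose argmax sequence is [2, 2, 3]
# decodes to dst_str [2]; keep_idx_list is [1, 2] when blanks are kept in the
# index list and [1] otherwise.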
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
"""
CTC decoder using multiple processes.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_first_part_point + sorted_last_part_point
sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
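# Note (added): points are ordered by their projection onto the average
# predicted direction, so a roughly horizontal text line is sorted left to
# right; instances with >= 16 points are re-sorted in two halves so the order
# better follows curved lines.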
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if 0 <= ly < h and 0 <= lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if 0 <= ry < h and 0 <= rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if 0 <= ly < h and 0 <= lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if 0 <= ry < h and 0 <= rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
is_h_angle = abs(np.sum(
main_direction * reference_direction)) < math.cos(math.pi / 180 * 70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Warp all the function together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
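# Note (added): returns a unit-length (y, x) vector averaging the predicted
# direction field over the given points; the caller uses it to decide whether
# an instance is laid out horizontally or vertically.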
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_first_part_point + sorted_last_part_point
sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
# pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
if __name__ == '__main__':
np.random.seed(0)
logits_map = np.random.random([10, 20, 33])
# a list of (y, x) coordinates
instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)]
instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)]
instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)]
gather_info_list = [
instance_gather_info_1, instance_gather_info_2, instance_gather_info_3
]
time0 = time.time()
res = ctc_decoder_for_image(
gather_info_list, logits_map, keep_blank_in_idxs=True)
print(res)
print('cost {}'.format(time.time() - time0))
print('--' * 20)
"""
Algorithms for computing the skeleton of a binary image
"""
import numpy as np
from scipy import ndimage as ndi
G123_LUT = np.array(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0
],
dtype=bool)
G123P_LUT = np.array(
[
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
dtype=bool)
def thin(image, max_iter=None):
"""
Perform morphological thinning of a binary image.
Parameters
----------
image : binary (M, N) ndarray
The image to be thinned.
max_iter : int, number of iterations, optional
Regardless of the value of this parameter, the thinned image
is returned immediately if an iteration produces no change.
If this parameter is specified it thus sets an upper bound on
the number of iterations performed.
Returns
-------
out : ndarray of bool
Thinned image.
See also
--------
skeletonize, medial_axis
Notes
-----
This algorithm [1]_ works by making multiple passes over the image,
removing pixels matching a set of criteria designed to thin
connected regions while preserving eight-connected components and
2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
correlates the intermediate skeleton image with a neighborhood mask,
then looks up each neighborhood in a lookup table indicating whether
the central pixel should be deleted in that sub-iteration.
References
----------
.. [1] Z. Guo and R. W. Hall, "Parallel thinning with
two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
pp. 359-373, 1989. :DOI:`10.1145/62065.62074`
.. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
Methodologies-A Comprehensive Survey," IEEE Transactions on
Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
p. 879, 1992. :DOI:`10.1109/34.161346`
Examples
--------
>>> square = np.zeros((7, 7), dtype=np.uint8)
>>> square[1:-1, 2:-2] = 1
>>> square[0, 1] = 1
>>> square
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
>>> skel = thin(square)
>>> skel.astype(np.uint8)
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
"""
# convert image to uint8 with values in {0, 1}
skel = np.asanyarray(image, dtype=bool).astype(np.uint8)
# neighborhood mask
mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8)
# iterate until convergence, up to the iteration limit
max_iter = max_iter or np.inf
n_iter = 0
n_pts_old, n_pts_new = np.inf, np.sum(skel)
while n_pts_old != n_pts_new and n_iter < max_iter:
n_pts_old = n_pts_new
# perform the two "subiterations" described in the paper
for lut in [G123_LUT, G123P_LUT]:
# correlate image with neighborhood mask
N = ndi.correlate(skel, mask, mode='constant')
# take deletion decision from this subiteration's LUT
D = np.take(lut, N)
# perform deletion
skel[D] = 0
n_pts_new = np.sum(skel) # count points after thinning
n_iter += 1
return skel.astype(bool)
import os
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
"""
"""
result_path = './out'
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
valid_set = 'partvgg'
gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
text_path = os.path.join(gt_dir, im_prefix + '.txt')
with open(text_path, 'r') as fid:
lines = [line.strip() for line in fid.readlines()]
for line in lines:
if valid_set == 'partvgg':
tokens = line.strip().split('\t')[0].split(',')
# tokens = line.strip().split(',')
coords = tokens[:]
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, 4, 2)
elif valid_set == 'totaltext':
tokens = line.strip().split('\t')[0].split(',')
coords = tokens[:]
coords_len = len(coords) // 2  # integer division; reshape needs an int
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 0, 0),
thickness=2)
for detected_poly, recognized_str in zip(poly_list, seq_strs):
cv2.polylines(
vis_det_img,
np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
isClosed=True,
color=(0, 0, 255),
thickness=2)
cv2.putText(
vis_det_img,
recognized_str,
org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
vis_det_img)
def visualization_output(src_image,
f_tcl,
f_chars,
output_dir,
image_prefix=None):
"""
"""
# restore BGR image, CHW -> HWC
im_mean = [0.485, 0.456, 0.406]
im_std = [0.229, 0.224, 0.225]
im_mean = np.array(im_mean).reshape((3, 1, 1))
im_std = np.array(im_std).reshape((3, 1, 1))
src_image *= im_std
src_image += im_mean
src_image = src_image.transpose([1, 2, 0])
src_image = src_image[:, :, ::-1] * 255  # RGB -> BGR for OpenCV
H, W, _ = src_image.shape
file_prefix = image_prefix if image_prefix is not None else str(
int(time.time() * 1000))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# visualization f_tcl
tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
vis_tcl_img = src_image.copy()
f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
vis_tcl_img[:, :, 1] = f_tcl_resized * 255
cv2.imwrite(tcl_file_name, vis_tcl_img)
# visualization char maps
vis_char_img = src_image.copy()
# CHW -> HWC
char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
f_chars[f_chars < 95] = 1.0
f_chars[f_chars == 95] = 0.0
f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
vis_char_img[:, :, 1] = f_chars_resized * 255
cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
result_path):
"""
"""
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
# draw gt bbox on the image.
text_path = os.path.join(gt_dir, im_prefix + '.txt')
with open(text_path, 'r') as fid:
lines = [line.strip() for line in fid.readlines()]
for line in lines:
tokens = line.strip().split('\t')
coords = tokens[0].split(',')
coords_len = len(coords)
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len // 2, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 255, 255),
thickness=1)
for point, point_pair in zip(point_list, point_pair_list):
cv2.line(
vis_det_img,
tuple(point_pair[0]),
tuple(point_pair[1]), (0, 255, 255),
thickness=1)
cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
vis_det_img)
def resize_image(im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
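# Worked example (added; numbers are illustrative): a 720 x 1280 image with
# max_side_len=512 gives ratio = 512 / 1280 = 0.4, i.e. (288, 512) before
# stride alignment and (384, 512) after rounding up to multiples of 128,
# so ratio_h = 384 / 720 and ratio_w = 512 / 1280.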
def resize_image_min(im, max_side_len=512):
"""
"""
print('--> Using resize_image_min')
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
# Fix the longer side
# if resize_h > resize_w:
# ratio = float(max_side_len) / resize_h
# else:
# ratio = float(max_side_len) / resize_w
###
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
# construct polygon
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
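# Worked example (added): two point pairs [p0_top, p0_bottom] and
# [p1_top, p1_bottom] yield the polygon [p0_top, p1_top, p1_bottom, p0_bottom],
# i.e. the top edge left-to-right followed by the bottom edge right-to-left.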
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
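# Note (added): only the corner points of the polygon are moved; the first
# and last quads are stretched outward by shrink_ratio_of_width times their
# width, so the expanded polygon fully covers the first and last characters.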
def norm2(x, axis=None):
if axis is not None:  # "if axis:" would wrongly treat axis=0 as no axis
return np.sqrt(np.sum(x**2, axis=axis))
return np.sqrt(np.sum(x**2))
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
def generate_direction_info(image_fn,
H,
W,
ratio_h,
ratio_w,
max_length=640,
out_scale=4,
gt_dir=None):
"""
"""
im_basename = os.path.basename(image_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
if gt_dir is None:
gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
# get gt label map
text_path = os.path.join(gt_dir, im_prefix + '.txt')
with open(text_path, 'r') as fid:
lines = [line.strip() for line in fid.readlines()]
for label_idx, line in enumerate(lines, start=1):
coords, txt = line.strip().split('\t')
if txt == '###':
continue
tokens = coords.strip().split(',')
coords = list(map(float, tokens))
poly = np.array(coords).reshape(4, 2) * np.array(
[ratio_w, ratio_h]).reshape(1, 2) / out_scale
mid_idx = poly.shape[0] // 2
direct_vector = (
(poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
direct_vector /= len(txt)
# l2_distance = norm2(direct_vector)
# avg_char_distance = l2_distance / len(txt)
avg_char_distance = 1.0
direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
cv2.fillPoly(instance_direction_map,
poly.round().astype(np.int32)[np.newaxis, :, :],
direct_label)
instance_direction_map = instance_direction_map.transpose([2, 0, 1])
return instance_direction_map[:2, ...]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
if len(dt_boxes) > 0:
src_im = img
for box, txt in zip(dt_boxes, strs):  # avoid shadowing the builtin str
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(src_im, txt, org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.7, color=(0, 255, 0), thickness=1)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The e2e Image saved in {}".format(save_path))
def main():
global_config = config['Global']
# build model
model = build_model(config['Architecture'])
init_model(config, model, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'])
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
with open(save_res_path, "wb") as fout:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
# write result
dt_boxes_json = []
for poly, txt in zip(points, strs):
tmp_json = {"transcription": txt}
tmp_json['points'] = poly.tolist()
dt_boxes_json.append(tmp_json)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
draw_e2e_res(points, strs, config, src_img, file)
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
\ No newline at end of file
@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
+ args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
@@ -374,7 +375,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
- 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+ 'CLS', 'PG'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'