Commit a7a899f6 authored by myhloli's avatar myhloli
Browse files

feat(model): add OCR model base structure and utilities

- Add base model structure for OCR in pytorch
- Implement data augmentation and transformation modules
- Create utilities for dictionary handling and state dict conversion
- Include post-processing modules for OCR
- Add weight initialization and loading functions
parent 72e66c2d
u
g
_
i
m
/
1
I
L
S
V
R
C
2
0
v
a
l
8
5
3
6
9
.
j
p
ق
ا
پ
ل
4
7
ئ
ى
ش
ت
ي
ك
د
ف
ر
و
ن
ب
ە
خ
ې
چ
ۇ
ز
س
م
ۋ
گ
ڭ
ۆ
ۈ
ج
غ
ھ
ژ
s
c
e
n
w
P
E
D
U
d
r
b
y
B
o
O
Y
N
T
k
t
h
A
H
F
z
W
K
G
M
f
Z
X
Q
J
x
q
-
!
%
#
?
:
$
,
&
'
É
@
é
(
+
u
k
_
i
m
g
/
1
6
I
L
S
V
R
C
2
0
v
a
l
7
9
.
j
p
в
і
д
п
о
н
с
т
ю
4
5
3
а
и
м
е
р
ч
у
Б
з
л
к
8
А
В
г
є
б
ь
х
ґ
ш
ц
ф
я
щ
ж
Г
Х
У
Т
Е
І
Н
П
З
Л
Ю
С
Д
М
К
Р
Ф
О
Ц
И
Я
Ч
Ш
Ж
Є
Ґ
Ь
s
c
e
n
w
A
P
r
E
t
o
h
d
y
M
G
N
F
B
T
D
U
O
W
Z
f
H
Y
b
K
z
x
Q
X
q
J
$
-
'
#
&
%
?
:
!
,
+
@
(
é
É
u
r
_
i
m
g
/
3
I
L
S
V
R
C
2
0
1
v
a
l
9
7
8
.
j
p
چ
ٹ
پ
ا
ئ
ی
ے
4
6
و
ل
ن
ڈ
ھ
ک
ت
ش
ف
ق
ر
د
5
ب
ج
خ
ہ
س
ز
غ
ڑ
ں
آ
م
ؤ
ط
ص
ح
ع
گ
ث
ض
ذ
ۓ
ِ
ء
ظ
ً
ي
ُ
ۃ
أ
ٰ
ە
ژ
ۂ
ة
ّ
ك
ه
s
c
e
n
w
o
d
t
D
M
T
U
E
b
P
h
y
W
H
A
x
B
O
N
G
Y
Q
F
k
K
q
J
Z
f
z
X
'
@
&
!
,
:
$
-
#
?
%
é
+
(
É
x
i
_
m
g
/
1
0
I
L
S
V
R
C
2
v
a
l
3
6
4
5
.
j
p
Q
u
e
r
o
8
7
n
c
9
t
b
é
q
d
ó
y
F
s
,
O
í
T
f
"
U
M
h
:
P
H
A
E
D
z
N
á
ñ
ú
%
;
è
+
Y
-
B
G
(
)
¿
?
w
¡
!
X
É
K
k
Á
ü
Ú
«
»
J
'
ö
W
Z
º
Ö
­
[
]
Ç
ç
à
ä
û
ò
Í
ê
ô
ø
ª
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
_
`
~
\ No newline at end of file
raise ValueError('utils -> e2e_utils -> extract_batchsize')
import paddle
import numpy as np
import copy
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
ngpu = int(batch_size / img_bs)
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
pad_num, tcl_bs):
label_list = label_list.numpy()
batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(batch):
for j in range(max_text_nums):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
label_list_t, tcl_bs)
label = []
tt = [l.tolist() for l in label_list]
for i in range(tcl_bs):
k = 0
for j in range(max_text_length):
if tt[i][j][0] != pad_num:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)]
probs_seq = logits_seq
labels = np.argmax(probs_seq, axis=1)
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
detal = len(gather_info) // (pts_num - 1)
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list,
logits_map,
Lexicon_Table,
pts_num=6):
"""
CTC decoder using multiple processes.
"""
decoder_str = []
decoder_xys = []
for gather_info in gather_info_list:
if len(gather_info) < pts_num:
continue
dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num)
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
decoder_str.append(dst_str_readable)
decoder_xys.append(xys_list)
return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
src_h, valid_set):
poly_list = []
keep_str_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
offset_expand = 1.0
if valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
detected_poly = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
keep_str_list.append(keep_str)
if valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
return poly_list, keep_str_list
def generate_pivot_list_fast(p_score,
p_char_maps,
f_direction,
Lexicon_Table,
score_thresh=0.5):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
return keep_yxs_list, decoded_str
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
"""
CTC decoder using multiple processes.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
pred_strs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
pred_strs.append(decoded_str)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return pred_strs, instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
is_h_angle = abs(np.sum(
main_direction * reference_directin)) < math.cos(math.pi / 180 *
70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_slow(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Warp all the function together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from .extract_textpoint_slow import *
from .extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
# two different post-process
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
shape_list):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.outs_dict = outs_dict
self.shape_list = shape_list
def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, torch.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
instance_yxs_list, seq_strs = generate_pivot_list_fast(
p_score,
p_char,
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
def pg_postprocess_slow(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, torch.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
is_curved = self.valid_set == "totaltext"
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment