"web/vscode:/vscode.git/clone" did not exist on "4d41bd595c1e2bf55f9e3ccee0921b1213c0184a"
Commit b3d6785d authored by myhloli

refactor(ocr): remove unused code and simplify model architecture

- Remove unused imports and code
- Simplify model architecture by removing unnecessary components
- Update initialization and forward pass logic
- Rename variables for consistency
parent 3cb156f5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology import thin  # public API instead of the private _skeletonize module
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
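# Illustrative usage (not part of the original file): given a dict file with
# one character per line, e.g. b"0\n1\na\nb\n", get_dict returns the lexicon
# as a flat list of characters:
#   get_dict('ppocr_keys.txt')  # hypothetical path -> ['0', '1', 'a', 'b']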
def point_pair2poly(point_pair_list):
"""
Convert vertical point pairs into polygon points in clockwise order.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
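# Worked example (added for clarity): three point pairs (t0, b0), (t1, b1),
# (t2, b2), with t* on the top edge and b* on the bottom edge, are unrolled
# into the clockwise ring [t0, t1, t2, b2, b1, b0]: top edge left-to-right,
# then bottom edge right-to-left. pair_info carries the (max, min, mean)
# pair widths for later filtering.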
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Shrink a quad along its width by the given begin/end width ratios.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
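# Subtracting the row-wise max keeps np.exp from overflowing; the result is
# unchanged because softmax is shift-invariant. Illustrative check:
#   softmax(np.array([[1000.0, 1000.0]]))  # -> [[0.5, 0.5]], no inf/nan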
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicates and return the kept items with their positional indices.
remove_blank should be None (keep blanks) or the blank index (e.g. 95).
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
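# Worked example (added for clarity): for labels [1, 1, 0, 2, 2, 2] and
# remove_blank=0, groupby yields the runs (1, len 2), (0, len 1), (2, len 3).
# The blank run is dropped and every kept run contributes its middle position,
# so the result is keep_char_idx_list=[1, 2], keep_pos_idx_list=[1, 4].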
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
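# Illustrative decode (assuming 3 classes with blank index 2): per-step
# argmaxes [0, 0, 2, 1, 1] deduplicate to [0, 2, 1] with representative time
# steps [1, 2, 4]; removing the blank leaves dst_str=[0, 1], while
# keep_idx_list keeps or drops the blank's step depending on keep_blank_in_idxs.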
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[y, x], [y, x] ...]
logits_map: H x W X (n_chars + 1)
"""
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
"""
CTC greedy decoder applied to each text instance in turn.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_first_part_point + sorted_last_part_point
sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
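# How the sort works: every point is projected onto the mean direction vector
# and ordered by that scalar projection, i.e. roughly from the start of the
# text line to its end. Instances with >= 16 points are re-sorted half by
# half with local mean directions, which is more robust for curved text.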
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
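# Expansion sketch: the mean direction over the first/last third of the
# sorted points gives a unit step, and the center line is extrapolated
# outward from both ends by append_num steps so short skeletons better cover
# the ends of the text. The v2 variant below additionally stops as soon as a
# step leaves the binary TCL region (binary_tcl_map <= 0.5).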
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
pred_strs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
pred_strs.append(decoded_str)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return pred_strs, instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list, f_direction)  # y x
reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
is_h_angle = abs(np.sum(main_direction * reference_direction)) < math.cos(
math.pi / 180 * 70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_slow(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Wrap all the functions together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
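# Minimal usage sketch (shapes inferred from the code above):
#   p_score:     1 x h x w TCL score map
#   p_char_maps: (n_chars + 1) x h x w per-pixel character logits
#   f_direction: 2 x h x w direction field
# With is_backbone=True and is_curved=True this returns
# (pred_strs, instance_center_pos_yxs) from the curved branch.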
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_first_part_point + sorted_last_part_point
sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from .extract_textpoint_slow import *
from .extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
# two different post-process
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
shape_list):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.outs_dict = outs_dict
self.shape_list = shape_list
def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, torch.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
instance_yxs_list, seq_strs = generate_pivot_list_fast(
p_score,
p_char,
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
def pg_postprocess_slow(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, torch.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
is_curved = self.valid_set == "totaltext"
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_delta = offset / offset_length * expand_length
offset = offset + offset_delta
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
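# Minimal usage sketch (illustrative; argument names follow the code above,
# the dict path is hypothetical):
#   post = PGNet_PostProcess('ppocr_keys.txt', valid_set='totaltext',
#                            score_thresh=0.5, outs_dict=outs_dict,
#                            shape_list=shape_list)
#   data = post.pg_postprocess_fast()  # {'points': [...], 'texts': [...]}
# shape_list[0] must hold (src_h, src_w, ratio_h, ratio_w) for the image.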
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def resize_image(im, max_side_len=512):
"""
Resize the image so that each side is a multiple of max_stride, as required by the network.
:param im: the input image
:param max_side_len: limit on the longer side, to avoid running out of GPU memory
:return: the resized image and the resize ratios (ratio_h, ratio_w)
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
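# Worked example (added for clarity): a 720 x 1280 image with max_side_len=512
# gives ratio = 512 / 1280 = 0.4, i.e. 288 x 512; rounding each side up to a
# multiple of 128 yields 384 x 512, so the returned ratios are
# ratio_h = 384 / 720 ~= 0.533 and ratio_w = 512 / 1280 = 0.4.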
def resize_image_min(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
"""
Convert vertical point pairs into polygon points in clockwise order.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Shrink a quad along its width by the given begin/end width ratios.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def norm2(x, axis=None):
# use an explicit None check so axis=0 is handled correctly
if axis is not None:
return np.sqrt(np.sum(x**2, axis=axis))
return np.sqrt(np.sum(x**2))
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
import os
import sys
import logging
import functools
import torch.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# if dist.get_rank() == 0:
# logger.setLevel(log_level)
# else:
# logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
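# Minimal usage sketch (illustrative):
#   logger = get_logger('ppocr', log_file='./output/train.log')
#   logger.info('hello')  # always streamed to stdout
# Note: the FileHandler is only attached when torch.distributed reports rank 0,
# so the default process group must be initialized before passing log_file.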
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
def points2polygon(points):
"""Convert k points to 1 polygon.
Args:
points (ndarray or list): A ndarray or a list of shape (2k)
that indicates k points.
Returns:
polygon (Polygon): A polygon object.
"""
if isinstance(points, list):
points = np.array(points)
assert isinstance(points, np.ndarray)
assert (points.size % 2 == 0) and (points.size >= 8)
point_mat = points.reshape([-1, 2])
return Polygon(point_mat)
def poly_intersection(poly_det, poly_gt, buffer=0.0001):
Calculate the intersection area between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
intersection_area (float): The intersection area between the two polygons.
poly_inter (Polygon): The intersection polygon itself.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
if buffer == 0:
poly_inter = poly_det & poly_gt
else:
poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
return poly_inter.area, poly_inter
def poly_union(poly_det, poly_gt):
Calculate the union area between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
union_area (float): The union area between two polygons.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
area_det = poly_det.area
area_gt = poly_gt.area
area_inters, _ = poly_intersection(poly_det, poly_gt)
return area_det + area_gt - area_inters
def valid_boundary(x, with_score=True):
num = len(x)
if num < 8:
return False
if num % 2 == 0 and (not with_score):
return True
if num % 2 == 1 and with_score:
return True
return False
def boundary_iou(src, target):
"""Calculate the IOU between two boundaries.
Args:
src (list): Source boundary.
target (list): Target boundary.
Returns:
iou (float): The iou between two boundaries.
"""
assert valid_boundary(src, False)
assert valid_boundary(target, False)
src_poly = points2polygon(src)
target_poly = points2polygon(target)
return poly_iou(src_poly, target_poly)
def poly_iou(poly_det, poly_gt):
"""Calculate the IOU between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
Returns:
iou (float): The IOU between two polygons.
"""
assert isinstance(poly_det, Polygon)
assert isinstance(poly_gt, Polygon)
area_inters, _ = poly_intersection(poly_det, poly_gt)
area_union = poly_union(poly_det, poly_gt)
if area_union == 0:
return 0.0
return area_inters / area_union
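# Worked example (added for clarity): two axis-aligned unit squares shifted
# by 0.5 along x intersect with area 0.5 and union 1.5, so
#   poly_iou(points2polygon([0, 0, 1, 0, 1, 1, 0, 1]),
#            points2polygon([0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]))
# is approximately 0.5 / 1.5 = 1/3 (up to the small buffer epsilon).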
def poly_nms(polygons, threshold):
assert isinstance(polygons, list)
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
keep_poly = []
index = [i for i in range(polygons.shape[0])]
while len(index) > 0:
keep_poly.append(polygons[index[-1]].tolist())
A = polygons[index[-1]][:-1]
index = np.delete(index, -1)
iou_list = np.zeros((len(index), ))
for i in range(len(index)):
B = polygons[index[i]][:-1]
iou_list[i] = boundary_iou(A, B)
remove_index = np.where(iou_list > threshold)
index = np.delete(index, remove_index)
return keep_poly
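# Usage sketch (illustrative): each entry of `polygons` is a flat list of 2k
# coordinates followed by a confidence score. Candidates are taken from the
# highest score downward, and any remaining boundary whose IoU with the kept
# one exceeds `threshold` is suppressed:
#   keep = poly_nms([[0, 0, 1, 0, 1, 1, 0, 1, 0.9],
#                    [0, 0, 1, 0, 1, 1, 0, 1, 0.8]], threshold=0.5)
#   # -> only the 0.9-score polygon survives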
import os
import imghdr
import cv2
import logging
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if imghdr.what(file_path) in img_end:
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
return imgs_lists
def check_and_read_gif(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
gif = cv2.VideoCapture(img_path)
ret, frame = gif.read()
if not ret:
# logger = logging.getLogger('ppocr')
print("Cannot read {}. This gif image maybe corrupted.")
# logger.info("Cannot read {}. This gif image maybe corrupted.")
return None, False
if len(frame.shape) == 2 or frame.shape[-1] == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
def check_and_read(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
gif = cv2.VideoCapture(img_path)
ret, frame = gif.read()
if not ret:
logger = logging.getLogger('ppocr')
logger.info("Cannot read {}. This gif image maybe corrupted.")
return None, False
if len(frame.shape) == 2 or frame.shape[-1] == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True, False
elif os.path.basename(img_path)[-3:] in ['pdf']:
import fitz
from PIL import Image
imgs = []
with fitz.open(img_path) as pdf:
for pg in range(0, pdf.pageCount):
page = pdf[pg]
mat = fitz.Matrix(2, 2)
pm = page.getPixmap(matrix=mat, alpha=False)
# if width or height > 2000 pixels, don't enlarge the image
if pm.width > 2000 or pm.height > 2000:
pm = page.getPixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
imgs.append(img)
return imgs, False, True
return None, False, False
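# Note (added): pdf.pageCount and page.getPixmap are legacy PyMuPDF (fitz)
# names; recent PyMuPDF releases expose them as pdf.page_count and
# page.get_pixmap, so this branch may need adjusting on newer environments.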
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import copy
import numpy as np
import math
import time
import torch
-from pytorchocr.base_ocr_v20 import BaseOCRV20
+from ...pytorchocr.base_ocr_v20 import BaseOCRV20
-import tools.infer.pytorchocr_utility as utility
+from . import pytorchocr_utility as utility
-from pytorchocr.postprocess import build_post_process
+from ...pytorchocr.postprocess import build_post_process
-from pytorchocr.utils.utility import get_image_file_list, check_and_read_gif
class TextClassifier(BaseOCRV20):
@@ -34,7 +26,8 @@ class TextClassifier(BaseOCRV20):
self.weights_path = args.cls_model_path
self.yaml_path = args.cls_yaml_path
-network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
+# network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
+network_config = utility.get_arch_config(self.weights_path)
super(TextClassifier, self).__init__(network_config, **kwargs)
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
@@ -46,7 +39,7 @@ class TextClassifier(BaseOCRV20):
self.net.eval()
# if self.use_gpu:
#     self.net.cuda()
-self.net = self.net.to(self.device)
+self.net.to(self.device)
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
@@ -119,38 +112,3 @@ class TextClassifier(BaseOCRV20):
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_classifier = TextClassifier(args)
valid_image_file_list = []
img_list = []
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
except:
print(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
ino]))
print("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
if __name__ == '__main__':
main(utility.parse_args())
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
import cv2
import numpy as np
import time
import json
import torch
-from pytorchocr.base_ocr_v20 import BaseOCRV20
+from ...pytorchocr.base_ocr_v20 import BaseOCRV20
-import tools.infer.pytorchocr_utility as utility
+from . import pytorchocr_utility as utility
-from pytorchocr.utils.utility import get_image_file_list, check_and_read_gif
-from pytorchocr.data import create_operators, transform
+from ...pytorchocr.data import create_operators, transform
-from pytorchocr.postprocess import build_post_process
+from ...pytorchocr.postprocess import build_post_process
class TextDetector(BaseOCRV20):
@@ -123,13 +114,14 @@ class TextDetector(BaseOCRV20):
self.weights_path = args.det_model_path
self.yaml_path = args.det_yaml_path
-network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
+# network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
+network_config = utility.get_arch_config(self.weights_path)
super(TextDetector, self).__init__(network_config, **kwargs)
self.load_pytorch_weights(self.weights_path)
self.net.eval()
# if self.use_gpu:
#     self.net.cuda()
-self.net = self.net.to(self.device)
+self.net.to(self.device)
def order_points_clockwise(self, pts):
"""
@@ -231,38 +223,3 @@ class TextDetector(BaseOCRV20):
elapse = time.time() - starttime
return dt_boxes, elapse
if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args)
count = 0
total_time = 0
draw_img_save = "./inference_results"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
dt_boxes, elapse = text_detector(img)
if count > 0:
total_time += elapse
count += 1
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
print(save_pred)
print("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
print("The visualized image saved in {}".format(img_path))
if count > 1:
print("Avg Time: {}".format(total_time / (count - 1)))
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
from PIL import Image
import cv2
import numpy as np
import math
import time
import torch
-from pytorchocr.base_ocr_v20 import BaseOCRV20
+from ...pytorchocr.base_ocr_v20 import BaseOCRV20
-import tools.infer.pytorchocr_utility as utility
+from . import pytorchocr_utility as utility
-from pytorchocr.postprocess import build_post_process
+from ...pytorchocr.postprocess import build_post_process
-from pytorchocr.utils.utility import get_image_file_list, check_and_read_gif
class TextRecognizer(BaseOCRV20):
@@ -87,7 +80,8 @@ class TextRecognizer(BaseOCRV20):
self.yaml_path = args.rec_yaml_path
char_num = len(getattr(self.postprocess_op, 'character'))
-network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path, char_num)
+# network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path, char_num)
+network_config = utility.get_arch_config(self.weights_path)
weights = self.read_pytorch_weights(self.weights_path)
self.out_channels = self.get_out_channels(weights)
@@ -103,7 +97,7 @@ class TextRecognizer(BaseOCRV20):
self.net.eval()
# if self.use_gpu:
#     self.net.cuda()
-self.net = self.net.to(self.device)
+self.net.to(self.device)
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
@@ -452,33 +446,3 @@ class TextRecognizer(BaseOCRV20):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse += time.time() - starttime
return rec_res, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args)
valid_image_file_list = []
img_list = []
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
exit()
for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
ino]))
print("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
if __name__ == '__main__':
main(utility.parse_args())
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import copy
import numpy as np
import time
from PIL import Image
-import tools.infer.pytorchocr_utility as utility
-import tools.infer.predict_rec as predict_rec
-import tools.infer.predict_det as predict_det
-import tools.infer.predict_cls as predict_cls
-from pytorchocr.utils.utility import get_image_file_list, check_and_read_gif
-from tools.infer.pytorchocr_utility import draw_ocr_box_txt
+from . import predict_rec
+from . import predict_det
+from . import predict_cls
class TextSystem(object):
@@ -121,51 +109,3 @@ def sorted_boxes(dt_boxes):
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
print("{}, {:.3f}".format(text, score))
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
boxes = dt_boxes
txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr_box_txt(
image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
if __name__ == '__main__':
main(utility.parse_args())
-import os, sys
+import os
import math
+from pathlib import Path
import numpy as np
import cv2
-from PIL import Image, ImageDraw, ImageFont
import argparse
+root_dir = Path(__file__).resolve().parent.parent.parent
+DEFAULT_CFG_PATH = root_dir / "pytorchocr" / "modeling" / "arch_config.yaml"
def init_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
@@ -64,7 +69,7 @@ def init_args():
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_path", type=str)
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
-parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
+parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
@@ -165,215 +170,6 @@ def AnalysisConfig(weights_path, yaml_path=None, char_num=None):
if yaml_path is not None:
return read_network_config_from_yaml(yaml_path, char_num=char_num)
weights_basename = os.path.basename(weights_path)
weights_name = weights_basename.lower()
# supported_weights = ['ch_ptocr_server_v2.0_det_infer.pth',
# 'ch_ptocr_server_v2.0_rec_infer.pth',
# 'ch_ptocr_mobile_v2.0_det_infer.pth',
# 'ch_ptocr_mobile_v2.0_rec_infer.pth',
# 'ch_ptocr_mobile_v2.0_cls_infer.pth',
# ]
# assert weights_name in supported_weights, \
# "supported weights are {} but input weights is {}".format(supported_weights, weights_name)
if weights_name == 'ch_ptocr_server_v2.0_det_infer.pth':
network_config = {'model_type':'det',
'algorithm':'DB',
'Transform':None,
'Backbone':{'name':'ResNet_vd', 'layers':18, 'disable_se':True},
'Neck':{'name':'DBFPN', 'out_channels':256},
'Head':{'name':'DBHead', 'k':50}}
elif weights_name == 'ch_ptocr_server_v2.0_rec_infer.pth':
network_config = {'model_type':'rec',
'algorithm':'CRNN',
'Transform':None,
'Backbone':{'name':'ResNet', 'layers':34},
'Neck':{'name':'SequenceEncoder', 'hidden_size':256, 'encoder_type':'rnn'},
'Head':{'name':'CTCHead', 'fc_decay': 4e-05}}
elif weights_name in ['ch_ptocr_mobile_v2.0_det_infer.pth']:
network_config = {'model_type': 'det',
'algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large', 'scale': 0.5, 'disable_se': True},
'Neck': {'name': 'DBFPN', 'out_channels': 96},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name =='ch_ptocr_mobile_v2.0_rec_infer.pth':
network_config = {'model_type':'rec',
'algorithm':'CRNN',
'Transform':None,
'Backbone':{'model_name':'small', 'name':'MobileNetV3', 'scale':0.5, 'small_stride':[1,2,2,2]},
'Neck':{'name':'SequenceEncoder', 'hidden_size':48, 'encoder_type':'rnn'},
'Head':{'name':'CTCHead', 'fc_decay': 4e-05}}
elif weights_name == 'ch_ptocr_mobile_v2.0_cls_infer.pth':
network_config = {'model_type':'cls',
'algorithm':'CLS',
'Transform':None,
'Backbone':{'name':'MobileNetV3', 'model_name':'small', 'scale':0.35},
'Neck':None,
'Head':{'name':'ClsHead', 'class_dim':2}}
elif weights_name == 'ch_ptocr_v2_rec_infer.pth':
network_config = {'model_type': 'rec',
'algorithm': 'CRNN',
'Transform': None,
'Backbone': {'name': 'MobileNetV1Enhance', 'scale': 0.5},
'Neck': {'name': 'SequenceEncoder', 'hidden_size': 64, 'encoder_type': 'rnn'},
'Head': {'name': 'CTCHead', 'mid_channels': 96, 'fc_decay': 2e-05}}
elif weights_name == 'ch_ptocr_v2_det_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large', 'scale': 0.5, 'disable_se': True},
'Neck': {'name': 'DBFPN', 'out_channels': 96},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name == 'ch_ptocr_v3_rec_infer.pth':
network_config = {'model_type':'rec',
'algorithm':'CRNN',
'Transform':None,
'Backbone':{'name':'MobileNetV1Enhance',
'scale':0.5,
'last_conv_stride': [1, 2],
'last_pool_type': 'avg'},
'Neck':{'name':'SequenceEncoder',
'dims': 64,
'depth': 2,
'hidden_dims': 120,
'use_guide': True,
'encoder_type':'svtr'},
'Head':{'name':'CTCHead', 'fc_decay': 2e-05}
}
elif weights_name == 'ch_ptocr_v3_det_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large', 'scale': 0.5, 'disable_se': True},
'Neck': {'name': 'RSEFPN', 'out_channels': 96, 'shortcut': True},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name == 'det_mv3_db_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large'},
'Neck': {'name': 'DBFPN', 'out_channels': 256},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name == 'det_r50_vd_db_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'ResNet_vd', 'layers': 50},
'Neck': {'name': 'DBFPN', 'out_channels': 256},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name == 'det_mv3_east_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'EAST',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large'},
'Neck': {'name': 'EASTFPN', 'model_name': 'small'},
'Head': {'name': 'EASTHead', 'model_name': 'small'}}
elif weights_name == 'det_r50_vd_east_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'EAST',
'Transform': None,
'Backbone': {'name': 'ResNet_vd', 'layers': 50},
'Neck': {'name': 'EASTFPN', 'model_name': 'large'},
'Head': {'name': 'EASTHead', 'model_name': 'large'}}
elif weights_name == 'det_r50_vd_sast_icdar15_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'SAST',
'Transform': None,
'Backbone': {'name': 'ResNet_SAST', 'layers': 50},
'Neck': {'name': 'SASTFPN', 'with_cab': True},
'Head': {'name': 'SASTHead'}}
elif weights_name == 'det_r50_vd_sast_totaltext_v2.0_infer.pth':
network_config = {'model_type': 'det',
'algorithm': 'SAST',
'Transform': None,
'Backbone': {'name': 'ResNet_SAST', 'layers': 50},
'Neck': {'name': 'SASTFPN', 'with_cab': True},
'Head': {'name': 'SASTHead'}}
elif weights_name == 'en_server_pgneta_infer.pth':
network_config = {'model_type': 'e2e',
'algorithm': 'PGNet',
'Transform': None,
'Backbone': {'name': 'ResNet', 'layers': 50},
'Neck': {'name': 'PGFPN'},
'Head': {'name': 'PGHead'}}
elif weights_name == 'en_ptocr_mobile_v2.0_table_det_infer.pth':
network_config = {'model_type': 'det','algorithm': 'DB',
'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'model_name': 'large', 'scale': 0.5, 'disable_se': False},
'Neck': {'name': 'DBFPN', 'out_channels': 96},
'Head': {'name': 'DBHead', 'k': 50}}
elif weights_name == 'en_ptocr_mobile_v2.0_table_rec_infer.pth':
network_config = {'model_type': 'rec',
'algorithm': 'CRNN',
'Transform': None,
'Backbone': {'model_name': 'large', 'name': 'MobileNetV3', },
'Neck': {'name': 'SequenceEncoder', 'hidden_size': 96, 'encoder_type': 'rnn'},
'Head': {'name': 'CTCHead', 'fc_decay': 4e-05}}
elif 'om_' in weights_name and '_rec_' in weights_name:
network_config = {'model_type': 'rec',
'algorithm': 'CRNN',
'Transform': None,
'Backbone': {'model_name': 'small', 'name': 'MobileNetV3', 'scale': 0.5,
'small_stride': [1, 2, 2, 2]},
'Neck': {'name': 'SequenceEncoder', 'hidden_size': 48, 'encoder_type': 'om'},
'Head': {'name': 'CTCHead', 'fc_decay': 4e-05}}
else:
network_config = {'model_type': 'rec',
'algorithm': 'CRNN',
'Transform': None,
'Backbone': {'model_name': 'small', 'name': 'MobileNetV3', 'scale': 0.5,
'small_stride': [1, 2, 2, 2]},
'Neck': {'name': 'SequenceEncoder', 'hidden_size': 48, 'encoder_type': 'rnn'},
'Head': {'name': 'CTCHead', 'fc_decay': 4e-05}}
# raise NotImplementedError
return network_config
def draw_e2e_res(dt_boxes, strs, img_path):
src_im = cv2.imread(img_path)
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(
src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
return src_im
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
def resize_img(img, input_size=600):
"""
@@ -387,58 +183,6 @@ def resize_img(img, input_size=600):
return img
def draw_ocr_box_txt(image,
boxes,
txts,
scores=None,
drop_score=0.5,
font_path="./doc/simfang.ttf"):
h, w = image.height, image.width
img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255))
import random
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
draw_right = ImageDraw.Draw(img_right)
for idx, (box, txt) in enumerate(zip(boxes, txts)):
if scores is not None and scores[idx] < drop_score:
continue
color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255))
draw_left.polygon(box, fill=color)
draw_right.polygon(
[
box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
box[2][1], box[3][0], box[3][1]
],
outline=color)
box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
1])**2)
box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
1])**2)
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
cur_y = box[0][1]
for c in txt:
char_size = font.getsize(c)
draw_right.text(
(box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
cur_y += char_size[1]
else:
font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
draw_right.text(
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
return np.array(img_show)
def str_count(s):
"""
Count the number of Chinese characters,
@@ -463,82 +207,6 @@ def str_count(s):
return s_len - math.ceil(en_dg_count / 2)
def text_visual(texts,
scores,
img_h=400,
img_w=600,
threshold=0.,
font_path="./doc/simfang.ttf"):
"""
create new blank img and draw txt on it
args:
texts(list): the text will be draw
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
"""
if scores is not None:
assert len(texts) == len(
scores), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
blank_img[:, img_w - 1:] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
return blank_img, draw_txt
blank_img, draw_txt = create_blank_img()
font_size = 20
txt_color = (0, 0, 0)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
gap = font_size + 5
txt_img_list = []
count, index = 1, 0
for idx, txt in enumerate(texts):
index += 1
if scores[idx] < threshold or math.isnan(scores[idx]):
index -= 1
continue
first_line = True
while str_count(txt) >= img_w // font_size - 4:
tmp = txt
txt = tmp[:img_w // font_size - 4]
if first_line:
new_txt = str(index) + ': ' + txt
first_line = False
else:
new_txt = ' ' + txt
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
txt = tmp[img_w // font_size - 4:]
if count >= img_h // gap - 1:
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
if first_line:
new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
else:
new_txt = " " + txt + " " + '%.3f' % (scores[idx])
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
# whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts):
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
txt_img_list.append(np.array(blank_img))
if len(txt_img_list) == 1:
blank_img = np.array(txt_img_list[0])
else:
blank_img = np.concatenate(txt_img_list, axis=1)
return np.array(blank_img)
def base64_to_cv2(b64str):
import base64
data = base64.b64decode(b64str.encode('utf8'))
@@ -547,12 +215,13 @@ def base64_to_cv2(b64str):
return data
-def draw_boxes(image, boxes, scores=None, drop_score=0.5):
-if scores is None:
-scores = [1] * len(boxes)
-for (box, score) in zip(boxes, scores):
-if score < drop_score:
-continue
-box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
-image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
-return image
+def get_arch_config(model_path):
+from omegaconf import OmegaConf
+all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
+path = Path(model_path)
+file_name = path.stem
+if file_name not in all_arch_config:
+raise ValueError(f"architecture {file_name} is not in arch_config.yaml")
+arch_config = all_arch_config[file_name]
+return arch_config
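# Usage sketch (illustrative): for a weights file named
# "ch_ptocr_v3_rec_infer.pth", get_arch_config looks up the stem
# "ch_ptocr_v3_rec_infer" in arch_config.yaml and returns that architecture
# dict, replacing the hard-coded name-to-config table that AnalysisConfig
# used to carry above.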