Commit 5713e0ca authored by yangql

Initial commit
import sys
import cv2
import numpy as np
import pyclipper
import six
from shapely.geometry import Polygon
class DecodeImage:
"""decode image"""
def __init__(self, img_mode="RGB", channel_first=False):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data["image"]
if six.PY2:
assert (
type(img) is str and len(img) > 0
), "invalid input 'img' in DecodeImage"
else:
assert (
type(img) is bytes and len(img) > 0
), "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == "RGB":
assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data["image"] = img
return data
class NormalizeImage:
"""normalize image such as substract mean, divide std"""
def __init__(self, scale=None, mean=None, std=None, order="chw"):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype("float32")
self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
img = np.array(data["image"]).astype(np.float32)
data["image"] = (img * self.scale - self.mean) / self.std
return data
class ToCHWImage:
"""convert hwc image to chw image"""
def __init__(self):
pass
def __call__(self, data):
img = np.array(data["image"])
data["image"] = img.transpose((2, 0, 1))
return data
class KeepKeys:
def __init__(self, keep_keys):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if "image_shape" in kwargs:
self.image_shape = kwargs["image_shape"]
self.resize_type = 1
elif "limit_side_len" in kwargs:
self.limit_side_len = kwargs.get("limit_side_len", 736)
self.limit_type = kwargs.get("limit_type", "min")
if "resize_long" in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get("resize_long", 960)
else:
self.limit_side_len = kwargs.get("limit_side_len", 736)
self.limit_type = kwargs.get("limit_type", "min")
def __call__(self, data):
img = data["image"]
src_h, src_w = img.shape[:2]
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data["image"] = img
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w = img.shape[:2]
# limit the max side
if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.0
else:
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.0
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w = img.shape[:2]
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def transform(data, ops=None):
"""transform"""
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_dict):
"""
create operators based on the config
"""
ops = []
for op_name, param in op_param_dict.items():
if param is None:
param = {}
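        # resolve the operator class defined in this module by name via eval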
op = eval(op_name)(**param)
ops.append(op)
return ops
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
class DBPostProcess:
"""The post process for Differentiable Binarization (DB)."""
def __init__(
self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
score_mode="fast",
use_dilation=False,
):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
if use_dilation:
self.dilation_kernel = np.array([[1, 1], [1, 1]])
else:
self.dilation_kernel = None
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
"""
        _bitmap: single binarized map with shape (H, W),
        whose values are in {0, 1}
"""
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours(
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height
)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
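        # DB box expansion: offset each box outward by D = A * r / L,
        # where A is the polygon area, r the unclip ratio, L the perimeter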
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
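        # order the four corners as [top-left, top-right, bottom-right, bottom-left]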
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
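        # mean of the probability map inside the polygon, computed on the box's
        # axis-aligned bounding rectangle with a filled-polygon mask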
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
"""
        box_score_slow: use the polygon mean score as the box score
"""
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def __call__(self, pred, shape_list):
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel,
)
else:
mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(
pred[batch_index], mask, src_w, src_h
)
boxes_batch.append({"points": boxes})
return boxes_batch
from .text_recognize import TextRecognizer
# Model info
rec_img_shape: [3, 48, 320]
rec_batch_num: 6
import argparse
import math
import time
from typing import List
import cv2
import numpy as np
from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml
from .utils import CTCLabelDecode
class TextRecognizer:
def __init__(self, config):
self.session = OrtInferSession(config)
if self.session.have_key():
self.character_dict_path = self.session.get_character_list()
else:
self.character_dict_path = config.get("keys_path", None)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)
self.rec_batch_num = config["rec_batch_num"]
self.rec_image_shape = config["rec_img_shape"]
def __call__(self, img_list: List[np.ndarray]):
if isinstance(img_list, np.ndarray):
img_list = [img_list]
# Calculate the aspect ratio of all text bars
width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
norm_img_batch = []
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img_batch.append(norm_img[np.newaxis, :])
norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
starttime = time.time()
preds = self.session(norm_img_batch)[0]
rec_result = self.postprocess_op(preds)
for rno, one_res in enumerate(rec_result):
rec_res[indices[beg_img_no + rno]] = one_res
elapse += time.time() - starttime
return rec_res, elapse
def resize_norm_img(self, img, max_wh_ratio):
img_channel, img_height, img_width = self.rec_image_shape
assert img_channel == img.shape[2]
img_width = int(img_height * max_wh_ratio)
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(img_height * ratio) > img_width:
resized_w = img_width
else:
resized_w = int(math.ceil(img_height * ratio))
resized_image = cv2.resize(img, (resized_w, img_height))
resized_image = resized_image.astype("float32")
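        # HWC -> CHW, scale to [0, 1], then shift/scale to [-1, 1]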
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--image_path", type=str, help="image_dir|image_path")
parser.add_argument("--config_path", type=str, default="config.yaml")
args = parser.parse_args()
config = read_yaml(args.config_path)
text_recognizer = TextRecognizer(config)
img = cv2.imread(args.image_path)
rec_res, predict_time = text_recognizer(img)
print(f"rec result: {rec_res}\t cost: {predict_time}s")
import numpy as np
class CTCLabelDecode:
"""Convert between text-label and text-index"""
def __init__(self, character_dict_path):
super(CTCLabelDecode, self).__init__()
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None"
if isinstance(character_dict_path, str):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
else:
self.character_str = character_dict_path
self.character_str.append(" ")
dict_character = self.add_special_char(self.character_str)
self.character = dict_character
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
def __call__(self, preds, label=None):
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ["blank"] + dict_character
return dict_character
def get_ignored_tokens(self):
return [0] # for ctc blank
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
"""convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if (
idx > 0
and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
):
continue
char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = "".join(char_list)
result_list.append((text, np.mean(conf_list + [1e-50])))
return result_list
Global:
text_score: 0.5
use_angle_cls: true
use_text_det: true
print_verbose: false
min_height: 30
width_height_ratio: 8
Det:
model_path: ../../Resource/Models/ch_PP-OCRv3_det_infer.onnx
limit_side_len: 736
limit_type: min
thresh: 0.3
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.6
use_dilation: true
score_mode: fast
Cls:
model_path: ../../Resource/Models/ch_ppocr_mobile_v2.0_cls_infer.onnx
cls_image_shape: [3, 48, 192]
cls_batch_num: 6
cls_thresh: 0.9
label_list: ['0', '180']
Rec:
model_path: ../../Resource/Models/ch_PP-OCRv3_rec_infer.onnx
rec_img_shape: [3, 48, 320]
rec_batch_num: 6
import copy
import importlib
from pathlib import Path
from typing import Optional, Union
import cv2
import numpy as np
from .ch_ppocr_v2_cls import TextClassifier
from .ch_ppocr_v3_det import TextDetector
from .ch_ppocr_v3_rec import TextRecognizer
from .utils import LoadImage, UpdateParameters, concat_model_path, init_args, read_yaml
root_dir = Path(__file__).resolve().parent
class RapidOCR:
def __init__(self, config_path: Optional[str] = None, **kwargs):
if config_path is None:
config_path = str(root_dir / "config.yaml")
if not Path(config_path).exists():
            raise FileNotFoundError(f"{config_path} does not exist!")
config = read_yaml(config_path)
config = concat_model_path(config)
if kwargs:
updater = UpdateParameters()
config = updater(config, **kwargs)
global_config = config["Global"]
self.print_verbose = global_config["print_verbose"]
self.text_score = global_config["text_score"]
self.min_height = global_config["min_height"]
self.width_height_ratio = global_config["width_height_ratio"]
self.use_text_det = config["Global"]["use_text_det"]
if self.use_text_det:
self.text_detector = TextDetector(config["Det"])
self.text_recognizer = TextRecognizer(config["Rec"])
self.use_angle_cls = config["Global"]["use_angle_cls"]
if self.use_angle_cls:
self.text_cls = TextClassifier(config["Cls"])
self.load_img = LoadImage()
def __call__(self, img_content: Union[str, np.ndarray, bytes, Path], **kwargs):
if kwargs:
box_thresh = kwargs.get("box_thresh", 0.5)
unclip_ratio = kwargs.get("unclip_ratio", 1.6)
text_score = kwargs.get("text_score", 0.5)
self.text_detector.postprocess_op.box_thresh = box_thresh
self.text_detector.postprocess_op.unclip_ratio = unclip_ratio
self.text_score = text_score
img = self.load_img(img_content)
h, w = img.shape[:2]
if self.width_height_ratio == -1:
use_limit_ratio = False
else:
use_limit_ratio = w / h > self.width_height_ratio
if not self.use_text_det or h <= self.min_height or use_limit_ratio:
dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w)
det_elapse = 0.0
else:
dt_boxes, det_elapse = self.text_detector(img)
if dt_boxes is None or len(dt_boxes) < 1:
return None, None
if self.print_verbose:
print(f"dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}")
dt_boxes = self.sorted_boxes(dt_boxes)
img_crop_list = self.get_crop_img_list(img, dt_boxes)
cls_elapse = 0.0
if self.use_angle_cls:
img_crop_list, _, cls_elapse = self.text_cls(img_crop_list)
if self.print_verbose:
print(f"cls num: {len(img_crop_list)}, elapse: {cls_elapse}")
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
if self.print_verbose:
print(f"rec_res num: {len(rec_res)}, elapse: {rec_elapse}")
filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, rec_res)
        final_result = [
            [dt.tolist(), rec[0], str(rec[1])]
            for dt, rec in zip(filter_boxes, filter_rec_res)
        ]
        if final_result:
            return final_result, [det_elapse, cls_elapse, rec_elapse]
return None, None
@staticmethod
def init_module(module_name, class_name):
module_part = importlib.import_module(module_name)
return getattr(module_part, class_name)
def get_boxes_img_without_det(self, img, h, w):
x0, y0, x1, y1 = 0, 0, w, h
dt_boxes = np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
dt_boxes = dt_boxes[np.newaxis, ...]
img_crop_list = [img]
return dt_boxes, img_crop_list
def get_crop_img_list(self, img, dt_boxes):
def get_rotate_crop_image(img, points):
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3]),
)
)
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2]),
)
)
pts_std = np.float32(
[
[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height],
]
)
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M,
(img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC,
)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
img_crop_list = []
for box in dt_boxes:
tmp_box = copy.deepcopy(box)
img_crop = get_rotate_crop_image(img, tmp_box)
img_crop_list.append(img_crop)
return img_crop_list
@staticmethod
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
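        # bubble pass: boxes whose top-left y differs by < 10 px are treated as
        # the same text line and re-ordered left to right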
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if (
abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
and _boxes[j + 1][0][0] < _boxes[j][0][0]
):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def filter_boxes_rec_by_score(self, dt_boxes, rec_res):
filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.text_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res
def main():
args = init_args()
ocr_engine = RapidOCR(**vars(args))
result, elapse_list = ocr_engine(args.img_path)
print(result)
if args.print_cost:
print(elapse_list)
if __name__ == "__main__":
main()
import argparse
import traceback
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Union
import cv2
import numpy as np
import yaml
from onnxruntime import (
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from PIL import Image, UnidentifiedImageError
root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
class OrtInferSession:
def __init__(self, config):
sess_opt = SessionOptions()
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
self._verify_model(config["model_path"])
self.session = InferenceSession(
config["model_path"], sess_options=sess_opt, providers=['ROCMExecutionProvider']
)
def __call__(self, input_content: np.ndarray) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), [input_content]))
try:
return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
error_info = traceback.format_exc()
raise ONNXRuntimeError(error_info) from e
def get_input_names(
self,
):
return [v.name for v in self.session.get_inputs()]
def get_output_names(
self,
):
return [v.name for v in self.session.get_outputs()]
def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
def have_key(self, key: str = "character") -> bool:
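        # also caches the model's custom metadata, which get_character_list() reads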
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
return True
return False
@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")
if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
pass
class LoadImage:
def __init__(
self,
):
pass
def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} is not in {InputType.__args__}"
            )
img = self.load_img(img)
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if img.ndim == 3 and img.shape[2] == 4:
return self.cvt_four_to_three(img)
return img
def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, (str, Path)):
self.verify_exist(img)
try:
img = np.array(Image.open(img))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
except UnidentifiedImageError as e:
raise LoadImageError(f"cannot identify image file {img}") from e
return img
if isinstance(img, bytes):
img = np.array(Image.open(BytesIO(img)))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img
if isinstance(img, np.ndarray):
return img
raise LoadImageError(f"{type(img)} is not supported!")
@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
"""RGBA → RGB"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))
not_a = cv2.bitwise_not(a)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
new_img = cv2.bitwise_and(new_img, new_img, mask=a)
new_img = cv2.add(new_img, not_a)
return new_img
@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
pass
def read_yaml(yaml_path):
with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
def concat_model_path(config):
    key = "model_path"
    config["Det"][key] = str(root_dir / config["Det"][key])
    config["Rec"][key] = str(root_dir / config["Rec"][key])
    config["Cls"][key] = str(root_dir / config["Cls"][key])
    return config
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument("-img", "--img_path", type=str, default=None, required=True)
parser.add_argument("-p", "--print_cost", action="store_true", default=False)
global_group = parser.add_argument_group(title="Global")
global_group.add_argument("--text_score", type=float, default=0.5)
global_group.add_argument("--use_angle_cls", type=bool, default=True)
global_group.add_argument("--use_text_det", type=bool, default=True)
global_group.add_argument("--print_verbose", type=bool, default=False)
global_group.add_argument("--min_height", type=int, default=30)
global_group.add_argument("--width_height_ratio", type=int, default=8)
det_group = parser.add_argument_group(title="Det")
det_group.add_argument("--det_model_path", type=str, default=None)
det_group.add_argument("--det_limit_side_len", type=float, default=736)
det_group.add_argument(
"--det_limit_type", type=str, default="min", choices=["max", "min"]
)
det_group.add_argument("--det_thresh", type=float, default=0.3)
det_group.add_argument("--det_box_thresh", type=float, default=0.5)
det_group.add_argument("--det_unclip_ratio", type=float, default=1.6)
det_group.add_argument("--det_use_dilation", type=bool, default=True)
det_group.add_argument(
"--det_score_mode", type=str, default="fast", choices=["slow", "fast"]
)
cls_group = parser.add_argument_group(title="Cls")
cls_group.add_argument("--cls_model_path", type=str, default=None)
cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192])
cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"])
cls_group.add_argument("--cls_batch_num", type=int, default=6)
cls_group.add_argument("--cls_thresh", type=float, default=0.9)
rec_group = parser.add_argument_group(title="Rec")
rec_group.add_argument("--rec_model_path", type=str, default=None)
rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
rec_group.add_argument("--rec_batch_num", type=int, default=6)
args = parser.parse_args()
return args
class UpdateParameters:
def __init__(self) -> None:
pass
def parse_kwargs(self, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {}
for k, v in kwargs.items():
if k.startswith("det"):
det_dict[k] = v
elif k.startswith("cls"):
cls_dict[k] = v
elif k.startswith("rec"):
rec_dict[k] = v
else:
global_dict[k] = v
return global_dict, det_dict, cls_dict, rec_dict
def __call__(self, config, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs)
new_config = {
"Global": self.update_global_params(config["Global"], global_dict),
"Det": self.update_det_params(config["Det"], det_dict),
"Cls": self.update_cls_params(config["Cls"], cls_dict),
"Rec": self.update_rec_params(config["Rec"], rec_dict),
}
return new_config
def update_global_params(self, config, global_dict):
if global_dict:
config.update(global_dict)
return config
def update_det_params(self, config, det_dict):
if not det_dict:
return config
det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()}
model_path = det_dict.get('model_path', None)
if not model_path:
det_dict["model_path"] = str(root_dir / config["model_path"])
config.update(det_dict)
return config
    def update_cls_params(self, config, cls_dict):
        if not cls_dict:
            return config
        need_remove_prefix = ["cls_label_list", "cls_model_path"]
        new_cls_dict = self.remove_prefix(cls_dict, "cls_", need_remove_prefix)
        model_path = new_cls_dict.get("model_path", None)
        if not model_path:
            new_cls_dict["model_path"] = str(root_dir / config["model_path"])
        config.update(new_cls_dict)
        return config
def update_rec_params(self, config, rec_dict):
if not rec_dict:
return config
need_remove_prefix = ["rec_model_path"]
new_rec_dict = self.remove_prefix(rec_dict, 'rec_', need_remove_prefix)
model_path = new_rec_dict.get('model_path', None)
if not model_path:
new_rec_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_rec_dict)
return config
@staticmethod
def remove_prefix(
config: Dict[str, str], prefix: str, remove_params: List[str]
) -> Dict[str, str]:
new_rec_dict = {}
for k, v in config.items():
if k in remove_params:
k = k.split(prefix)[1]
new_rec_dict[k] = v
return new_rec_dict
pyclipper>=1.2.0
Shapely>=1.7.1
PyYAML
pytest
# RapidOCR
## Model Introduction
Currently the fastest known and most widely supported multi-platform, multi-language OCR that is fully open source, free, and supports fast offline deployment.
## Model Structure
RapidOCR recognizes text in images with three models: ch_PP-OCRv3_det + ch_ppocr_mobile_v2.0_cls + ch_PP-OCRv3_rec, which handle text detection, 180° orientation classification, and text recognition respectively.
## Python Inference
This example runs the RapidOCR models on the ONNXRuntime inference framework to recognize text in images. Download the model files from https://pan.baidu.com/s/1uGHhimKLb5k5f9xaFmNBwQ (extraction code: ggvz) and save ch_PP-OCRv3_det_infer.onnx, ch_ppocr_mobile_v2.0_cls_infer.onnx, and ch_PP-OCRv3_rec_infer.onnx to the Resource/Models folder. The steps below show how to run the Python example; see Tutorial_Python.md in the Doc directory for a detailed walkthrough.
### Download the Image
Pull the image from the SourceFind (光源) registry:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort1.14.0_migraphx3.0.0-dtk22.10.1
```
### Set the Python Environment Variable
```
export PYTHONPATH=/opt/dtk/lib:$PYTHONPATH
```
### Install Dependencies
```
# Enter the rapidocr ort project root directory
cd <path_to_rapidocr_ort>
# Enter the example program directory
cd Python/
# Install the dependencies
pip install -r requirements.txt
```
### Run the Example
```
python rapidocr.py
```
As shown below, given an input image, the RapidOcr model recognizes the text and the text boxes.
```
[[[[245.0, 9.0], [554.0, 8.0], [554.0, 27.0], [245.0, 28.0]], '人生活的真实写照:善有善报,恶有恶报。', '0.9306996673345566'], [[[9.0, 49.0], [522.0, 50.0], [522.0, 69.0], [9.0, 68.0]], '我们中国人有一句俗语说:“种瓜得瓜,种豆得豆。”而这就是每个', '0.9294075581335253'], [[[84.0, 105.0], [555.0, 104.0], [555.0, 125.0], [85.0, 127.0]], "every man's life: good begets good, and evil leads to evil.", '0.8932319914301237'], [[[28.0, 147.0], [556.0, 146.0], [556.0, 168.0], [28.0, 169.0]], 'melons; if he sows beans, he will reap beans." And this is true of', '0.900923888185131'], [[[0.0, 185.0], [524.0, 188.0], [524.0, 212.0], [0.0, 209.0]], 'We Chinese have a saying:"If a man plants melons, he will reap', '0.9216671202863965'], [[[295.0, 248.0], [553.0, 248.0], [553.0, 264.0], [295.0, 264.0]], '它不仅适用于今生,也适用于来世。', '0.927988795673146'], [[[14.0, 289.0], [554.0, 290.0], [554.0, 307.0], [14.0, 306.0]], '一每一个行为都有一种结果。在我看来,这种想法是全宇宙的道德基础;', '0.88565122719967'], [[[9.0, 330.0], [521.0, 330.0], [521.0, 349.0], [9.0, 349.0]], '假如说过去的日子曾经教给我们一些什么的话,那就是有因必有果一', '0.9162070232052957'], [[[343.0, 388.0], [555.0, 388.0], [555.0, 405.0], [343.0, 405.0]], 'in this world and the next.', '0.8764956444501877'], [[[15.0, 426.0], [554.0, 426.0], [554.0, 448.0], [15.0, 448.0]], 'opinion, is the moral foundation of the universe; it applies equally', '0.9183026262815448'], [[[62.0, 466.0], [556.0, 468.0], [556.0, 492.0], [62.0, 490.0]], 'effect - every action has a consequence. This thought, in my', '0.9308378403304053']]
```
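You can also call the pipeline directly from Python instead of going through rapidocr.py. Below is a minimal sketch, assuming the sources in this repo are importable as the rapidocr_onnxruntime package (the upstream package layout) and that config.yaml points at valid model paths; the keyword overrides mirror the kwargs handled in RapidOCR.__call__:
```python
# Minimal usage sketch of the RapidOCR class defined in this commit.
# Assumption: the sources are importable as the rapidocr_onnxruntime package.
from rapidocr_onnxruntime import RapidOCR

engine = RapidOCR()  # loads the config.yaml bundled next to the package by default
result, elapse_list = engine("test.jpg", box_thresh=0.5, unclip_ratio=1.6, text_score=0.5)
if result is not None:
    for box, text, score in result:  # each entry is [box points, text, score]
        print(text, score)
```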
## C++ Inference
This example runs the RapidOCR models on the ONNXRuntime inference framework to recognize text in images. Download the model files from https://pan.baidu.com/s/1uGHhimKLb5k5f9xaFmNBwQ (extraction code: ggvz) and save ch_PP-OCRv3_det_infer.onnx, ch_ppocr_mobile_v2.0_cls_infer.onnx, and ch_PP-OCRv3_rec_infer.onnx to the Resource/Models folder. The steps below show how to run the C++ example; see Tutorial_Cpp.md in the Doc directory for a detailed walkthrough.
### Download the Image
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort1.14.0_migraphx3.0.0-dtk22.10.1
```
### Build the Project
```
rbuild build -d depend
```
### Set Environment Variables
Add the dependency libraries to the LD_LIBRARY_PATH environment variable by appending the following line to ~/.bashrc:
```
export LD_LIBRARY_PATH=<path_to_rapidocr_ort>/depend/lib64/:$LD_LIBRARY_PATH
```
Then run:
```
source ~/.bashrc
source /opt/dtk/env.sh
```
### Run the Example
```
# Enter the rapidocr ort project root directory
cd <path_to_rapidocr_ort>
# Enter the build directory
cd build/
# Run the example program
./RapidOcr
```
As shown below, given an input image, the RapidOcr model recognizes the text and the text boxes; the results are saved in the /Resource/Images/ folder.
```
TextBox[0](+padding)[score(0.711119),[x: 293, y: 58], [x: 604, y: 58], [x: 604, y: 79], [x: 293, y: 79]]
...
TextBox[11](+padding)[score(0.605026),[x: 92, y: 554], [x: 610, y: 557], [x: 609, y: 585], [x: 92, y: 582]]
---------- step: drawTextBoxes ----------
---------- step: angleNet getAngles ----------
angle[0][index(1), score(1.000000), time(57.276707ms)]
...
angle[11][index(1), score(0.930842), time(2.952602ms)]
---------- step: crnnNet getTextLine ----------
textLine[0](人生活的真实写照:善有善报,恶有恶报。)
textScores[0]{0.576271 ,0.99956 ,0.999475 ,0.99967 ,0.998779 ,0.999525 ,0.805865 ,0.999865 ,0.988233 ,0.999061 ,0.999581 ,0.999483 ,0.999324 ,0.995648 ,0.561861 ,0.961845 ,0.995993 ,0.998593 ,0.994963}
crnnTime[0](58.019418ms)
...
textLine[11](If the past has taught us anything, it is that every cause brings)
textScores[11]{0.996653 ,0.625094 ,0.97989 ,0.999761 ,0.816289 ,0.99883 ,0.963821 ,0.999222 ,0.999725 ,0.999588 ,0.542554 ,0.998707 ,0.911063 ,0.603935 ,0.99833 ,0.994734 ,0.998606 ,0.999571 ,0.9995 ,0.99971 ,0.983833 ,0.941867 ,0.989647 ,0.999145 ,0.998365 ,0.995752 ,0.999369 ,0.999424 ,0.976135 ,0.998815 ,0.999755 ,0.67898 ,0.999837 ,0.999205 ,0.982815 ,0.991013 ,0.999252 ,0.818822 ,0.996863 ,0.998451 ,0.999198 ,0.812635 ,0.999701 ,0.567811 ,0.999545 ,0.815998 ,0.996471 ,0.998722 ,0.999546 ,0.999121 ,0.999202 ,0.99971 ,0.980306 ,0.999399 ,0.635116 ,0.99954 ,0.998961 ,0.600432 ,0.990555 ,0.999872 ,0.998974 ,0.999687 ,0.56602 ,0.999607 ,0.999343}
crnnTime[11](38.051758ms)
```
## Source Repository and Feedback
https://developer.hpccube.com/codes/modelzoo/rapidocr_ort
## References
https://github.com/RapidAI/RapidOCR
https://github.com/RapidAI/RapidOcrOnnx
#include "AngleNet.h"
#include "OcrUtils.h"
#include <numeric>
void AngleNet::setGpuIndex(int gpuIndex) {
}
AngleNet::~AngleNet() {
delete session;
inputNamesPtr.clear();
outputNamesPtr.clear();
}
void AngleNet::setNumThread(int numOfThread) {
numThread = numOfThread;
sessionOptions.SetInterOpNumThreads(numThread);
}
void AngleNet::initModel(const std::string &pathStr) {
    //Set up the DCU (ROCm) execution provider
OrtROCMProviderOptions rocm_options;
rocm_options.device_id = 0;
sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
inputNamesPtr = getInputNames(session);
outputNamesPtr = getOutputNames(session);
}
Angle scoreToAngle(const std::vector<float> &outputData) {
int maxIndex = 0;
float maxScore = 0;
for (size_t i = 0; i < outputData.size(); i++) {
if (outputData[i] > maxScore) {
maxScore = outputData[i];
maxIndex = i;
}
}
return {maxIndex, maxScore};
}
Angle AngleNet::getAngle(cv::Mat &src) {
std::vector<float> inputTensorValues = substractMeanNormalize(src, meanValues, normValues);
std::array<int64_t, 4> inputShape{1, src.channels(), src.rows, src.cols};
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(),
inputTensorValues.size(), inputShape.data(),
inputShape.size());
assert(inputTensor.IsTensor());
std::vector<const char *> inputNames = {inputNamesPtr.data()->get()};
std::vector<const char *> outputNames = {outputNamesPtr.data()->get()};
auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
inputNames.size(), outputNames.data(), outputNames.size());
assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1,
std::multiplies<int64_t>());
float *floatArray = outputTensor.front().GetTensorMutableData<float>();
std::vector<float> outputData(floatArray, floatArray + outputCount);
return scoreToAngle(outputData);
}
std::vector<Angle> AngleNet::getAngles(std::vector<cv::Mat> &partImgs, const char *path,
const char *imgName, bool doAngle, bool mostAngle) {
size_t size = partImgs.size();
std::vector<Angle> angles(size);
if (doAngle) {
for (size_t i = 0; i < size; ++i) {
double startAngle = getCurrentTime();
cv::Mat angleImg;
cv::resize(partImgs[i], angleImg, cv::Size(dstWidth, dstHeight));
Angle angle = getAngle(angleImg);
double endAngle = getCurrentTime();
angle.time = endAngle - startAngle;
angles[i] = angle;
            //Save the angle debug image
if (isOutputAngleImg) {
std::string angleImgFile = getDebugImgFilePath(path, imgName, i, "-angle-");
saveImg(angleImg, angleImgFile.c_str());
}
}
} else {
for (size_t i = 0; i < size; ++i) {
angles[i] = Angle{-1, 0.f};
}
}
    //Pick the most likely angle index across all crops
if (doAngle && mostAngle) {
auto angleIndexes = getAngleIndexes(angles);
double sum = std::accumulate(angleIndexes.begin(), angleIndexes.end(), 0.0);
double halfPercent = angles.size() / 2.0f;
int mostAngleIndex;
if (sum < halfPercent) {//all angle set to 0
mostAngleIndex = 0;
} else {//all angle set to 1
mostAngleIndex = 1;
}
for (size_t i = 0; i < angles.size(); ++i) {
Angle angle = angles[i];
angle.index = mostAngleIndex;
angles.at(i) = angle;
}
}
return angles;
}
#include "CrnnNet.h"
#include "OcrUtils.h"
#include <fstream>
#include <numeric>
void CrnnNet::setGpuIndex(int gpuIndex) {
}
CrnnNet::~CrnnNet() {
delete session;
inputNamesPtr.clear();
outputNamesPtr.clear();
}
void CrnnNet::setNumThread(int numOfThread) {
numThread = numOfThread;
sessionOptions.SetInterOpNumThreads(numThread);
}
void CrnnNet::initModel(const std::string &pathStr, const std::string &keysPath) {
    //Set up the DCU (ROCm) execution provider
OrtROCMProviderOptions rocm_options;
rocm_options.device_id = 0;
sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
inputNamesPtr = getInputNames(session);
outputNamesPtr = getOutputNames(session);
//load keys
std::ifstream in(keysPath.c_str());
std::string line;
if (in) {
        while (getline(in, line)) {// each line read here excludes the trailing newline
keys.push_back(line);
}
} else {
printf("The keys.txt file was not found\n");
return;
}
keys.insert(keys.begin(), "#");
keys.emplace_back(" ");
printf("total keys size(%lu)\n", keys.size());
}
template<class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}
TextLine CrnnNet::scoreToTextLine(const std::vector<float> &outputData, size_t h, size_t w) {
auto keySize = keys.size();
auto dataSize = outputData.size();
std::string strRes;
std::vector<float> scores;
size_t lastIndex = 0;
size_t maxIndex;
float maxValue;
for (size_t i = 0; i < h; i++) {
size_t start = i * w;
size_t stop = (i + 1) * w;
if (stop > dataSize - 1) {
stop = (i + 1) * w - 1;
}
maxIndex = int(argmax(&outputData[start], &outputData[stop]));
maxValue = float(*std::max_element(&outputData[start], &outputData[stop]));
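        //greedy CTC decoding: skip the blank (index 0) and collapse consecutive repeated indices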
if (maxIndex > 0 && maxIndex < keySize && (!(i > 0 && maxIndex == lastIndex))) {
scores.emplace_back(maxValue);
strRes.append(keys[maxIndex]);
}
lastIndex = maxIndex;
}
return {strRes, scores};
}
TextLine CrnnNet::getTextLine(const cv::Mat &src) {
float scale = (float) dstHeight / (float) src.rows;
int dstWidth = int((float) src.cols * scale);
cv::Mat srcResize;
resize(src, srcResize, cv::Size(dstWidth, dstHeight));
std::vector<float> inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues);
std::array<int64_t, 4> inputShape{1, srcResize.channels(), srcResize.rows, srcResize.cols};
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(),
inputTensorValues.size(), inputShape.data(),
inputShape.size());
assert(inputTensor.IsTensor());
std::vector<const char *> inputNames = {inputNamesPtr.data()->get()};
std::vector<const char *> outputNames = {outputNamesPtr.data()->get()};
auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
inputNames.size(), outputNames.data(), outputNames.size());
assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1,
std::multiplies<int64_t>());
float *floatArray = outputTensor.front().GetTensorMutableData<float>();
std::vector<float> outputData(floatArray, floatArray + outputCount);
return scoreToTextLine(outputData, outputShape[1], outputShape[2]);
}
std::vector<TextLine> CrnnNet::getTextLines(std::vector<cv::Mat> &partImg, const char *path, const char *imgName) {
int size = partImg.size();
std::vector<TextLine> textLines(size);
for (int i = 0; i < size; ++i) {
//OutPut DebugImg
if (isOutputDebugImg) {
std::string debugImgFile = getDebugImgFilePath(path, imgName, i, "-debug-");
saveImg(partImg[i], debugImgFile.c_str());
}
//getTextLine
double startCrnnTime = getCurrentTime();
TextLine textLine = getTextLine(partImg[i]);
double endCrnnTime = getCurrentTime();
textLine.time = endCrnnTime - startCrnnTime;
textLines[i] = textLine;
}
return textLines;
}
#include "DbNet.h"
#include "OcrUtils.h"
void DbNet::setGpuIndex(int gpuIndex) {
}
DbNet::~DbNet() {
delete session;
inputNamesPtr.clear();
outputNamesPtr.clear();
}
void DbNet::setNumThread(int numOfThread) {
numThread = numOfThread;
sessionOptions.SetInterOpNumThreads(numThread);
}
void DbNet::initModel(const std::string &pathStr) {
    //Set up the DCU (ROCm) execution provider
OrtROCMProviderOptions rocm_options;
rocm_options.device_id = 0;
sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
inputNamesPtr = getInputNames(session);
outputNamesPtr = getOutputNames(session);
}
std::vector<TextBox> findRsBoxes(const cv::Mat &predMat, const cv::Mat &dilateMat, ScaleParam &s,
const float boxScoreThresh, const float unClipRatio) {
    const int longSideThresh = 3;//threshold on the long side of minBox
const int maxCandidates = 1000;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(dilateMat, contours, hierarchy, cv::RETR_LIST,
cv::CHAIN_APPROX_SIMPLE);
size_t numContours = contours.size() >= maxCandidates ? maxCandidates : contours.size();
std::vector<TextBox> rsBoxes;
for (size_t i = 0; i < numContours; i++) {
if (contours[i].size() <= 2) {
continue;
}
cv::RotatedRect minAreaRect = cv::minAreaRect(contours[i]);
float longSide;
std::vector<cv::Point2f> minBoxes = getMinBoxes(minAreaRect, longSide);
if (longSide < longSideThresh) {
continue;
}
float boxScore = boxScoreFast(minBoxes, predMat);
if (boxScore < boxScoreThresh)
continue;
//-----unClip-----
cv::RotatedRect clipRect = unClip(minBoxes, unClipRatio);
if (clipRect.size.height < 1.001 && clipRect.size.width < 1.001) {
continue;
}
//-----unClip-----
std::vector<cv::Point2f> clipMinBoxes = getMinBoxes(clipRect, longSide);
if (longSide < longSideThresh + 2)
continue;
std::vector<cv::Point> intClipMinBoxes;
for (auto &clipMinBox: clipMinBoxes) {
float x = clipMinBox.x / s.ratioWidth;
float y = clipMinBox.y / s.ratioHeight;
int ptX = (std::min)((std::max)(int(x), 0), s.srcWidth - 1);
int ptY = (std::min)((std::max)(int(y), 0), s.srcHeight - 1);
cv::Point point{ptX, ptY};
intClipMinBoxes.push_back(point);
}
rsBoxes.push_back(TextBox{intClipMinBoxes, boxScore});
}
reverse(rsBoxes.begin(), rsBoxes.end());
return rsBoxes;
}
std::vector<TextBox>
DbNet::getTextBoxes(cv::Mat &src, ScaleParam &s, float boxScoreThresh, float boxThresh, float unClipRatio) {
cv::Mat srcResize;
resize(src, srcResize, cv::Size(s.dstWidth, s.dstHeight));
std::vector<float> inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues);
std::array<int64_t, 4> inputShape{1, srcResize.channels(), srcResize.rows, srcResize.cols};
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(),
inputTensorValues.size(), inputShape.data(),
inputShape.size());
assert(inputTensor.IsTensor());
std::vector<const char *> inputNames = {inputNamesPtr.data()->get()};
std::vector<const char *> outputNames = {outputNamesPtr.data()->get()};
auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
inputNames.size(), outputNames.data(), outputNames.size());
assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1,
std::multiplies<int64_t>());
float *floatArray = outputTensor.front().GetTensorMutableData<float>();
std::vector<float> outputData(floatArray, floatArray + outputCount);
//-----Data preparation-----
int outHeight = (int) outputShape[2];
int outWidth = (int) outputShape[3];
size_t area = outHeight * outWidth;
std::vector<float> predData(area, 0.0);
std::vector<unsigned char> cbufData(area, ' ');
for (int i = 0; i < area; i++) {
predData[i] = float(outputData[i]);
cbufData[i] = (unsigned char) ((outputData[i]) * 255);
}
cv::Mat predMat(outHeight, outWidth, CV_32F, (float *) predData.data());
cv::Mat cBufMat(outHeight, outWidth, CV_8UC1, (unsigned char *) cbufData.data());
//-----boxThresh-----
const double maxValue = 255;
const double threshold = boxThresh * 255;
cv::Mat thresholdMat;
cv::threshold(cBufMat, thresholdMat, threshold, maxValue, cv::THRESH_BINARY);
//-----dilate-----
cv::Mat dilateMat;
cv::Mat dilateElement = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(thresholdMat, dilateMat, dilateElement);
return findRsBoxes(predMat, dilateMat, s, boxScoreThresh, unClipRatio);
}
#include "OcrLite.h"
#include "OcrUtils.h"
#include <stdarg.h> //windows&linux
OcrLite::OcrLite() {}
OcrLite::~OcrLite() {
if (isOutputResultTxt) {
fclose(resultTxt);
}
}
void OcrLite::setNumThread(int numOfThread) {
dbNet.setNumThread(numOfThread);
angleNet.setNumThread(numOfThread);
crnnNet.setNumThread(numOfThread);
}
void OcrLite::initLogger(bool isConsole, bool isPartImg, bool isResultImg) {
isOutputConsole = isConsole;
isOutputPartImg = isPartImg;
isOutputResultImg = isResultImg;
}
void OcrLite::enableResultTxt(const char *path, const char *imgName) {
isOutputResultTxt = true;
std::string resultTxtPath = getResultTxtFilePath(path, imgName);
printf("resultTxtPath(%s)\n", resultTxtPath.c_str());
resultTxt = fopen(resultTxtPath.c_str(), "w");
}
void OcrLite::setGpuIndex(int gpuIndex) {
dbNet.setGpuIndex(gpuIndex);
angleNet.setGpuIndex(-1);
crnnNet.setGpuIndex(gpuIndex);
}
bool OcrLite::initModels(const std::string &detPath, const std::string &clsPath,
const std::string &recPath, const std::string &keysPath) {
Logger("=====Init Models=====\n");
Logger("--- Init DbNet ---\n");
dbNet.initModel(detPath);
Logger("--- Init AngleNet ---\n");
angleNet.initModel(clsPath);
Logger("--- Init CrnnNet ---\n");
crnnNet.initModel(recPath, keysPath);
Logger("Init Models Success!\n");
return true;
}
void OcrLite::Logger(const char *format, ...) {
if (!(isOutputConsole || isOutputResultTxt)) return;
    char *buffer = (char *) malloc(8192);
    va_list args;
    va_start(args, format);
    vsnprintf(buffer, 8192, format, args);
va_end(args);
if (isOutputConsole) printf("%s", buffer);
if (isOutputResultTxt) fprintf(resultTxt, "%s", buffer);
free(buffer);
}
cv::Mat makePadding(cv::Mat &src, const int padding) {
if (padding <= 0) return src;
cv::Scalar paddingScalar = {255, 255, 255};
cv::Mat paddingSrc;
cv::copyMakeBorder(src, paddingSrc, padding, padding, padding, padding, cv::BORDER_ISOLATED, paddingScalar);
return paddingSrc;
}
OcrResult OcrLite::detect(const char *path, const char *imgName,
const int padding, const int maxSideLen,
float boxScoreThresh, float boxThresh, float unClipRatio, bool doAngle, bool mostAngle) {
std::string imgFile = getSrcImgFilePath(path, imgName);
cv::Mat originSrc = imread(imgFile, cv::IMREAD_COLOR);//default : BGR
int originMaxSide = (std::max)(originSrc.cols, originSrc.rows);
int resize;
if (maxSideLen <= 0 || maxSideLen > originMaxSide) {
resize = originMaxSide;
} else {
resize = maxSideLen;
}
resize += 2 * padding;
cv::Rect paddingRect(padding, padding, originSrc.cols, originSrc.rows);
cv::Mat paddingSrc = makePadding(originSrc, padding);
ScaleParam scale = getScaleParam(paddingSrc, resize);
OcrResult result;
result = detect(path, imgName, paddingSrc, paddingRect, scale,
boxScoreThresh, boxThresh, unClipRatio, doAngle, mostAngle);
return result;
}
OcrResult OcrLite::detect(const cv::Mat &mat, int padding, int maxSideLen, float boxScoreThresh, float boxThresh,
float unClipRatio, bool doAngle, bool mostAngle) {
cv::Mat originSrc = mat;
int originMaxSide = (std::max)(originSrc.cols, originSrc.rows);
int resize;
if (maxSideLen <= 0 || maxSideLen > originMaxSide) {
resize = originMaxSide;
} else {
resize = maxSideLen;
}
resize += 2 * padding;
cv::Rect paddingRect(padding, padding, originSrc.cols, originSrc.rows);
cv::Mat paddingSrc = makePadding(originSrc, padding);
ScaleParam scale = getScaleParam(paddingSrc, resize);
OcrResult result;
result = detect(NULL, NULL, paddingSrc, paddingRect, scale,
boxScoreThresh, boxThresh, unClipRatio, doAngle, mostAngle);
return result;
}
std::vector<cv::Mat> OcrLite::getPartImages(cv::Mat &src, std::vector<TextBox> &textBoxes,
const char *path, const char *imgName) {
std::vector<cv::Mat> partImages;
for (size_t i = 0; i < textBoxes.size(); ++i) {
cv::Mat partImg = getRotateCropImage(src, textBoxes[i].boxPoint);
partImages.emplace_back(partImg);
//OutPut DebugImg
if (isOutputPartImg) {
std::string debugImgFile = getDebugImgFilePath(path, imgName, i, "-part-");
saveImg(partImg, debugImgFile.c_str());
}
}
return partImages;
}
OcrResult OcrLite::detect(const char *path, const char *imgName,
cv::Mat &src, cv::Rect &originRect, ScaleParam &scale,
float boxScoreThresh, float boxThresh, float unClipRatio, bool doAngle, bool mostAngle) {
cv::Mat textBoxPaddingImg = src.clone();
int thickness = getThickness(src);
Logger("=====Start detect=====\n");
Logger("ScaleParam(sw:%d,sh:%d,dw:%d,dh:%d,%f,%f)\n", scale.srcWidth, scale.srcHeight,
scale.dstWidth, scale.dstHeight,
scale.ratioWidth, scale.ratioHeight);
Logger("---------- step: dbNet getTextBoxes ----------\n");
double startTime = getCurrentTime();
std::vector<TextBox> textBoxes = dbNet.getTextBoxes(src, scale, boxScoreThresh, boxThresh, unClipRatio);
double endDbNetTime = getCurrentTime();
double dbNetTime = endDbNetTime - startTime;
Logger("dbNetTime(%fms)\n", dbNetTime);
for (size_t i = 0; i < textBoxes.size(); ++i) {
Logger("TextBox[%d](+padding)[score(%f),[x: %d, y: %d], [x: %d, y: %d], [x: %d, y: %d], [x: %d, y: %d]]\n", i,
textBoxes[i].score,
textBoxes[i].boxPoint[0].x, textBoxes[i].boxPoint[0].y,
textBoxes[i].boxPoint[1].x, textBoxes[i].boxPoint[1].y,
textBoxes[i].boxPoint[2].x, textBoxes[i].boxPoint[2].y,
textBoxes[i].boxPoint[3].x, textBoxes[i].boxPoint[3].y);
}
Logger("---------- step: drawTextBoxes ----------\n");
drawTextBoxes(textBoxPaddingImg, textBoxes, thickness);
//---------- getPartImages ----------
std::vector<cv::Mat> partImages = getPartImages(src, textBoxes, path, imgName);
Logger("---------- step: angleNet getAngles ----------\n");
std::vector<Angle> angles;
angles = angleNet.getAngles(partImages, path, imgName, doAngle, mostAngle);
//Log Angles
for (size_t i = 0; i < angles.size(); ++i) {
Logger("angle[%d][index(%d), score(%f), time(%fms)]\n", i, angles[i].index, angles[i].score, angles[i].time);
}
//Rotate partImgs
for (size_t i = 0; i < partImages.size(); ++i) {
if (angles[i].index == 1) {
partImages.at(i) = matRotateClockWise180(partImages[i]);
}
}
Logger("---------- step: crnnNet getTextLine ----------\n");
std::vector<TextLine> textLines = crnnNet.getTextLines(partImages, path, imgName);
//Log TextLines
for (size_t i = 0; i < textLines.size(); ++i) {
Logger("textLine[%d](%s)\n", i, textLines[i].text.c_str());
std::ostringstream txtScores;
for (size_t s = 0; s < textLines[i].charScores.size(); ++s) {
if (s == 0) {
txtScores << textLines[i].charScores[s];
} else {
txtScores << " ," << textLines[i].charScores[s];
}
}
Logger("textScores[%d]{%s}\n", i, std::string(txtScores.str()).c_str());
Logger("crnnTime[%d](%fms)\n", i, textLines[i].time);
}
std::vector<TextBlock> textBlocks;
for (size_t i = 0; i < textLines.size(); ++i) {
std::vector<cv::Point> boxPoint = std::vector<cv::Point>(4);
int padding = originRect.x;//padding conversion
boxPoint[0] = cv::Point(textBoxes[i].boxPoint[0].x - padding, textBoxes[i].boxPoint[0].y - padding);
boxPoint[1] = cv::Point(textBoxes[i].boxPoint[1].x - padding, textBoxes[i].boxPoint[1].y - padding);
boxPoint[2] = cv::Point(textBoxes[i].boxPoint[2].x - padding, textBoxes[i].boxPoint[2].y - padding);
boxPoint[3] = cv::Point(textBoxes[i].boxPoint[3].x - padding, textBoxes[i].boxPoint[3].y - padding);
TextBlock textBlock{boxPoint, textBoxes[i].score, angles[i].index, angles[i].score,
angles[i].time, textLines[i].text, textLines[i].charScores, textLines[i].time,
angles[i].time + textLines[i].time};
textBlocks.emplace_back(textBlock);
}
double endTime = getCurrentTime();
double fullTime = endTime - startTime;
Logger("=====End detect=====\n");
Logger("FullDetectTime(%fms)\n", fullTime);
//cropped to original size
cv::Mat textBoxImg;
if (originRect.x > 0 && originRect.y > 0) {
textBoxPaddingImg(originRect).copyTo(textBoxImg);
} else {
textBoxImg = textBoxPaddingImg;
}
//Save result.jpg
if (isOutputResultImg) {
std::string resultImgFile = getResultImgFilePath(path, imgName);
imwrite(resultImgFile, textBoxImg);
}
std::string strRes;
for (auto &textBlock: textBlocks) {
strRes.append(textBlock.text);
strRes.append("\n");
}
return OcrResult{dbNetTime, textBlocks, textBoxImg, fullTime, strRes};
}
#ifdef __CLIB__
#include "OcrLiteCApi.h"
#include "OcrLite.h"
extern "C"
{
typedef struct {
OcrLite OcrObj;
std::string strRes;
} OCR_OBJ;
_QM_OCR_API OCR_HANDLE
OcrInit(const char *szDetModel, const char *szClsModel, const char *szRecModel, const char *szKeyPath, int nThreads) {
OCR_OBJ *pOcrObj = new OCR_OBJ;
if (pOcrObj) {
pOcrObj->OcrObj.setNumThread(nThreads);
pOcrObj->OcrObj.initModels(szDetModel, szClsModel, szRecModel, szKeyPath);
return pOcrObj;
} else {
return nullptr;
}
}
_QM_OCR_API OCR_BOOL
OcrDetect(OCR_HANDLE handle, const char *imgPath, const char *imgName, OCR_PARAM *pParam) {
OCR_OBJ *pOcrObj = (OCR_OBJ *) handle;
if (!pOcrObj)
return FALSE;
OCR_PARAM Param = *pParam;
if (Param.padding == 0)
Param.padding = 50;
if (Param.maxSideLen == 0)
Param.maxSideLen = 1024;
if (Param.boxScoreThresh == 0)
Param.boxScoreThresh = 0.6;
if (Param.boxThresh == 0)
Param.boxThresh = 0.3f;
if (Param.unClipRatio == 0)
Param.unClipRatio = 2.0;
if (Param.doAngle == 0)
Param.doAngle = 1;
if (Param.mostAngle == 0)
Param.mostAngle = 1;
OcrResult result = pOcrObj->OcrObj.detect(imgPath, imgName, Param.padding, Param.maxSideLen,
Param.boxScoreThresh, Param.boxThresh, Param.unClipRatio,
Param.doAngle != 0, Param.mostAngle != 0);
if (result.strRes.length() > 0) {
pOcrObj->strRes = result.strRes;
return TRUE;
} else
return FALSE;
}
_QM_OCR_API int OcrGetLen(OCR_HANDLE handle) {
OCR_OBJ *pOcrObj = (OCR_OBJ *) handle;
if (!pOcrObj)
return 0;
return pOcrObj->strRes.size() + 1;
}
_QM_OCR_API OCR_BOOL OcrGetResult(OCR_HANDLE handle, char *szBuf, int nLen) {
OCR_OBJ *pOcrObj = (OCR_OBJ *) handle;
if (!pOcrObj)
return FALSE;
    if (nLen > (int) pOcrObj->strRes.size()) {
        strncpy(szBuf, pOcrObj->strRes.c_str(), pOcrObj->strRes.size());
        szBuf[pOcrObj->strRes.size()] = 0;
    }
return pOcrObj->strRes.size();
}
_QM_OCR_API void OcrDestroy(OCR_HANDLE handle) {
OCR_OBJ *pOcrObj = (OCR_OBJ *) handle;
if (pOcrObj)
delete pOcrObj;
}
};
#endif