Commit 19d66e62 authored by zhoujun's avatar zhoujun
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into py_inference_doc

parents 204ab814 055f207f
......@@ -271,6 +271,59 @@ im_show.save('result.jpg')
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
```
### Use web images or numpy array as input
1. Web image
Use by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
Use by command line
```bash
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
```
2. Numpy array
Support numpy array as input only when used by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
## Parameter Description
| Parameter | Description | Default value |
......@@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
| use_space_char | Whether to recognize spaces | TRUE |
| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 |
| use_angle_cls | Whether to load classification model | FALSE |
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
| cls_image_shape | image shape of classification algorithm | "3,48,192" |
......@@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec | FALSE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
......@@ -26,17 +26,50 @@ import requests
from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.utility import initial_logger
from ppocr.utils.logging import get_logger
logger = initial_logger()
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
__all__ = ['PaddleOCR']
model_params = {
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
'rec':
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
model_urls = {
'det':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ic15_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
}
SUPPORT_DET_MODEL = ['DB']
......@@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong")
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
......@@ -84,13 +117,14 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path)
def parse_args():
def parse_args(mMain=True, add_help=True):
import argparse
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
if mMain:
parser = argparse.ArgumentParser(add_help=add_help)
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
......@@ -101,7 +135,8 @@ def parse_args():
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_max_side_len", type=float, default=960)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
......@@ -120,17 +155,64 @@ def parse_args():
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
"--rec_char_dict_path",
type=str,
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args()
else:
return argparse.Namespace(use_gpu=True,
ir_optim=True,
use_tensorrt=False,
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=2.0,
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=30,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=30,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False
)
class PaddleOCR(predict_system.TextSystem):
......@@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args()
postprocess_params = parse_args(mMain=False, add_help=False)
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if postprocess_params.rec_char_dict_path is None:
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
postprocess_params.rec_model_dir = os.path.join(
BASE_DIR, 'rec/{}'.format(lang))
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir, model_params['det'])
maybe_download(postprocess_params.rec_model_dir, model_params['rec'])
maybe_download(postprocess_params.det_model_dir, model_urls['det'])
maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
......@@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model
super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True):
def ocr(self, img, det=True, rec=True, cls=False):
"""
ocr with paddleocr
args:
......@@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem):
rec: use text recognition or not, if false, only det will be exec. default is True
"""
assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false')
exit(0)
self.use_angle_cls = cls
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
......@@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem):
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec:
dt_boxes, rec_res = self.__call__(img)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
......@@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem):
else:
if not isinstance(img, list):
img = [img]
if self.use_angle_cls:
img, cls_res, elapse = self.text_classifier(img)
if not rec:
return cls_res
rec_res, elapse = self.text_recognizer(img)
return rec_res
def main():
# for com
args = parse_args()
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
ocr_engine = PaddleOCR()
ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list:
print(img_path)
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = ocr_engine.ocr(img_path,
det=args.det,
rec=args.rec,
cls=args.use_angle_cls)
if result is not None:
for line in result:
print(line)
\ No newline at end of file
logger.info(line)
......@@ -26,6 +26,9 @@ from .randaugment import RandAugment
from .operators import *
from .label_ops import *
from .east_process import *
from .sast_process import *
def transform(data, ops=None):
""" transform """
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import cv2
import numpy as np
import json
import sys
import os
__all__ = ['EASTProcessTrain']
class EASTProcessTrain(object):
def __init__(self,
image_shape = [512, 512],
background_ratio = 0.125,
min_crop_side_ratio = 0.1,
min_text_size = 10,
**kwargs):
self.input_size = image_shape[1]
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
self.background_ratio = background_ratio
self.min_crop_side_ratio = min_crop_side_ratio
self.min_text_size = min_text_size
def preprocess(self, im):
input_size = self.input_size
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
# im = im[:, :, ::-1].astype(np.float32)
im = im / 255
im -= img_mean
im /= img_std
new_h, new_w, _ = im.shape
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
im_padded[:new_h, :new_w, :] = im
im_padded = im_padded.transpose((2, 0, 1))
im_padded = im_padded[np.newaxis, :]
return im_padded, im_scale
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if 0.333 < rand_degree_ratio < 0.666:
rand_degree_cnt = 2
elif rand_degree_ratio > 0.666:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4):
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx)\
- math.sin(rot_angle) * (sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx)\
+ math.cos(rot_angle) * (sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
dst_polys = np.array(dst_polys, dtype=np.float32)
return dst_im, dst_polys
def polygon_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def check_and_validate_polys(self, polys, tags, img_height, img_width):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
h, w = img_height, img_width
if polys.shape[0] == 0:
return polys
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
for poly, tag in zip(polys, tags):
p_area = self.polygon_area(poly)
#invalid poly
if abs(p_area) < 1:
continue
if p_area > 0:
#'poly in wrong direction'
if not tag:
tag = True #reversed cases should be ignore
poly = poly[(0, 3, 2, 1), :]
validated_polys.append(poly)
validated_tags.append(tag)
return np.array(validated_polys), np.array(validated_tags)
def draw_img_polys(self, img, polys):
if len(img.shape) == 4:
img = np.squeeze(img, axis=0)
if img.shape[0] == 3:
img = img.transpose((1, 2, 0))
img[:, :, 2] += 123.68
img[:, :, 1] += 116.78
img[:, :, 0] += 103.94
cv2.imwrite("tmp.jpg", img)
img = cv2.imread("tmp.jpg")
for box in polys:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
import random
ino = random.randint(0, 100)
cv2.imwrite("tmp_%d.jpg" % ino, img)
return
def shrink_poly(self, poly, r):
"""
fit a poly inside the origin poly, maybe bugs here...
used for generate the score map
:param poly: the text poly
:param r: r in the paper
:return: the shrinked poly
"""
# shrink ratio
R = 0.3
# find the longer pair
dist0 = np.linalg.norm(poly[0] - poly[1])
dist1 = np.linalg.norm(poly[2] - poly[3])
dist2 = np.linalg.norm(poly[0] - poly[3])
dist3 = np.linalg.norm(poly[1] - poly[2])
if dist0 + dist1 > dist2 + dist3:
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
## p0, p3
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
else:
## p0, p3
# print poly
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
return poly
def generate_quad(self, im_size, polys, tags):
"""
Generate quadrangle.
"""
h, w = im_size
poly_mask = np.zeros((h, w), dtype=np.uint8)
score_map = np.zeros((h, w), dtype=np.uint8)
# (x1, y1, ..., x4, y4, short_edge_norm)
geo_map = np.zeros((h, w, 9), dtype=np.float32)
# mask used during traning, to ignore some hard areas
training_mask = np.ones((h, w), dtype=np.uint8)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
r = [None, None, None, None]
for i in range(4):
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
r[i] = min(dist1, dist2)
# score map
shrinked_poly = self.shrink_poly(
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
cv2.fillPoly(score_map, shrinked_poly, 1)
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
# if the poly is too small, then ignore it during training
poly_h = min(
np.linalg.norm(poly[0] - poly[3]),
np.linalg.norm(poly[1] - poly[2]))
poly_w = min(
np.linalg.norm(poly[0] - poly[1]),
np.linalg.norm(poly[2] - poly[3]))
if min(poly_h, poly_w) < self.min_text_size:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
if tag:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
# geo map.
y_in_poly = xy_in_poly[:, 0]
x_in_poly = xy_in_poly[:, 1]
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
for pno in range(4):
geo_channel_beg = pno * 2
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
x_in_poly - poly[pno, 0]
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
y_in_poly - poly[pno, 1]
geo_map[y_in_poly, x_in_poly, 8] = \
1.0 / max(min(poly_h, poly_w), 1.0)
return score_map, geo_map, training_mask
def crop_area(self,
im,
polys,
tags,
crop_background=False,
max_tries=50):
"""
make random crop from the input image
:param im:
:param polys:
:param tags:
:param crop_background:
:param max_tries:
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_side_ratio * w or \
ymax - ymin < self.min_crop_side_ratio * h:
# area too small
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
& (polys[:, :, 0] <= xmax)\
& (polys[:, :, 1] >= ymin)\
& (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = []
tags = []
return im, polys, tags
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags
return im, polys, tags
def crop_background_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=True)
if len(text_polys) > 0:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
score_map = np.zeros((input_size, input_size), dtype=np.float32)
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
training_mask = np.ones((input_size, input_size), dtype=np.float32)
return im, score_map, geo_map, training_mask
def crop_foreground_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=False)
if text_polys.shape[0] == 0:
return None
#continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
text_polys[:, :, 0] *= ratio
text_polys[:, :, 1] *= ratio
_, _, new_h, new_w = im.shape
# print(im.shape)
# self.draw_img_polys(im, text_polys)
score_map, geo_map, training_mask = self.generate_quad(
(new_h, new_w), text_polys, text_tags)
return im, score_map, geo_map, training_mask
def __call__(self, data):
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
if im is None:
return None
if text_polys.shape[0] == 0:
return None
#add rotate cases
if np.random.rand() < 0.5:
im, text_polys = self.rotate_im_poly(im, text_polys)
h, w, _ = im.shape
text_polys, text_tags = self.check_and_validate_polys(text_polys,
text_tags, h, w)
if text_polys.shape[0] == 0:
return None
# random scale this image
rd_scale = np.random.choice(self.random_scale)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
if np.random.rand() < self.background_ratio:
outs = self.crop_background_infor(im, text_polys, text_tags)
else:
outs = self.crop_foreground_infor(im, text_polys, text_tags)
if outs is None:
return None
im, score_map, geo_map, training_mask = outs
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
geo_map = np.swapaxes(geo_map, 1, 2)
geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
training_mask = training_mask[np.newaxis, ::4, ::4]
training_mask = training_mask.astype(np.float32)
data['image'] = im[0]
data['score_map'] = score_map
data['geo_map'] = geo_map
data['training_mask'] = training_mask
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
return data
\ No newline at end of file
......@@ -52,6 +52,7 @@ class DetLabelEncode(object):
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
......@@ -70,6 +71,17 @@ class DetLabelEncode(object):
rect[3] = pts[np.argmax(diff)]
return rect
def expand_points_num(self, boxes):
max_points_num = 0
for box in boxes:
if len(box) > max_points_num:
max_points_num = len(box)
ex_boxes = []
for box in boxes:
ex_box = box + [box[-1]] * (max_points_num - len(box))
ex_boxes.append(ex_box)
return ex_boxes
class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """
......@@ -79,15 +91,17 @@ class BaseRecLabelEncode(object):
character_dict_path=None,
character_type='ch',
use_space_char=False):
support_character_type = ['ch', 'en', 'en_sensitive']
support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str)
support_character_type, character_type)
self.max_text_len = max_text_length
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type == "ch":
elif character_type in ["ch", "french", "german", "japan", "korean"]:
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
with open(character_dict_path, "rb") as fin:
......
......@@ -120,26 +120,37 @@ class DetResizeForTest(object):
if 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
if 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0:
img, shape = self.resize_image_type0(img)
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
img, shape = self.resize_image_type1(img)
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = shape
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, np.array([ori_h, ori_w])
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
......@@ -182,4 +193,31 @@ class DetResizeForTest(object):
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
return img, np.array([h, w])
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
This diff is collapsed.
This diff is collapsed.
......@@ -18,6 +18,8 @@ import copy
def build_loss(config):
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
# rec loss
from .rec_ctc_loss import CTCLoss
......@@ -25,7 +27,7 @@ def build_loss(config):
# cls loss
from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss']
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
This diff is collapsed.
This diff is collapsed.
......@@ -19,6 +19,7 @@ def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment