Commit 19d66e62 authored by zhoujun's avatar zhoujun
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into py_inference_doc

parents 204ab814 055f207f
...@@ -271,6 +271,59 @@ im_show.save('result.jpg') ...@@ -271,6 +271,59 @@ im_show.save('result.jpg')
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
``` ```
### Use web images or numpy array as input
1. Web image
Use by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
Use by command line
```bash
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
```
2. Numpy array
Support numpy array as input only when used by code
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line
result = ocr.ocr(img_path, cls=True)
for line in result:
print(line)
# show result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
## Parameter Description ## Parameter Description
| Parameter | Description | Default value | | Parameter | Description | Default value |
...@@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 | | max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt | | rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
| use_space_char | Whether to recognize spaces | TRUE | | use_space_char | Whether to recognize spaces | TRUE |
| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 |
| use_angle_cls | Whether to load classification model | FALSE | | use_angle_cls | Whether to load classification model | FALSE |
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None | | cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
| cls_image_shape | image shape of classification algorithm | "3,48,192" | | cls_image_shape | image shape of classification algorithm | "3,48,192" |
...@@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ ...@@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch | | lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
| det | Enable detction when `ppocr.ocr` func exec | TRUE | | det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec | FALSE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
...@@ -26,17 +26,50 @@ import requests ...@@ -26,17 +26,50 @@ import requests
from tqdm import tqdm from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.utility import initial_logger from ppocr.utils.logging import get_logger
logger = initial_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
model_params = { model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', 'det':
'rec': 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', 'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/ic15_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
...@@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path): ...@@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path):
progress_bar.update(len(data)) progress_bar.update(len(data))
file.write(data) file.write(data)
progress_bar.close() progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong") logger.error("Something went wrong while downloading models")
sys.exit(0) sys.exit(0)
...@@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url): ...@@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url):
# using custom model # using custom model
if not os.path.exists(os.path.join( if not os.path.exists(os.path.join(
model_storage_directory, 'model')) or not os.path.exists( model_storage_directory, 'model')) or not os.path.exists(
os.path.join(model_storage_directory, 'params')): os.path.join(model_storage_directory, 'params')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path)) print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True) os.makedirs(model_storage_directory, exist_ok=True)
...@@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url): ...@@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path) os.remove(tmp_path)
def parse_args(): def parse_args(mMain=True, add_help=True):
import argparse import argparse
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() if mMain:
# params for prediction engine parser = argparse.ArgumentParser(add_help=add_help)
parser.add_argument("--use_gpu", type=str2bool, default=True) # params for prediction engine
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str) # params for text detector
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_model_dir", type=str, default=None) parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_max_side_len", type=float, default=960) parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
# DB parmas parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) # DB parmas
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
# EAST parmas parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) # EAST parmas
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
# params for text recognizer parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None) # params for text recognizer
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument( parser.add_argument("--rec_batch_num", type=int, default=30)
"--rec_char_dict_path", parser.add_argument("--max_text_length", type=int, default=25)
type=str, parser.add_argument("--rec_char_dict_path", type=str, default=None)
default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--drop_score", type=float, default=0.5)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
# params for text classifier
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--use_zero_copy_run", type=bool, default=False) parser.add_argument("--label_list", type=list, default=['0', '180'])
return parser.parse_args() parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args()
else:
return argparse.Namespace(use_gpu=True,
ir_optim=True,
use_tensorrt=False,
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=2.0,
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=30,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=30,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False
)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args() postprocess_params = parse_args(mMain=False, add_help=False)
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if postprocess_params.rec_char_dict_path is None:
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
if postprocess_params.rec_model_dir is None: if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') postprocess_params.rec_model_dir = os.path.join(
BASE_DIR, 'rec/{}'.format(lang))
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params) print(postprocess_params)
# download model # download model
maybe_download(postprocess_params.det_model_dir, model_params['det']) maybe_download(postprocess_params.det_model_dir, model_urls['det'])
maybe_download(postprocess_params.rec_model_dir, model_params['rec']) maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
...@@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True): def ocr(self, img, det=True, rec=True, cls=False):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem):
rec: use text recognition or not, if false, only det will be exec. default is True rec: use text recognition or not, if false, only det will be exec. default is True
""" """
assert isinstance(img, (np.ndarray, list, str)) assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false')
exit(0)
self.use_angle_cls = cls
if isinstance(img, str): if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img image_file = img
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem):
if img is None: if img is None:
logger.error("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
return None return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
...@@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls:
img, cls_res, elapse = self.text_classifier(img)
if not rec:
return cls_res
rec_res, elapse = self.text_recognizer(img) rec_res, elapse = self.text_recognizer(img)
return rec_res return rec_res
def main(): def main():
# for com # for cmd
args = parse_args() args = parse_args(mMain=True)
image_file_list = get_image_file_list(args.image_dir) image_dir = args.image_dir
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0: if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir)) logger.error('no images find in {}'.format(args.image_dir))
return return
ocr_engine = PaddleOCR()
ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list: for img_path in image_file_list:
print(img_path) logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) result = ocr_engine.ocr(img_path,
for line in result: det=args.det,
print(line) rec=args.rec,
\ No newline at end of file cls=args.use_angle_cls)
if result is not None:
for line in result:
logger.info(line)
...@@ -26,6 +26,9 @@ from .randaugment import RandAugment ...@@ -26,6 +26,9 @@ from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
from .east_process import *
from .sast_process import *
def transform(data, ops=None): def transform(data, ops=None):
""" transform """ """ transform """
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import cv2
import numpy as np
import json
import sys
import os
__all__ = ['EASTProcessTrain']
class EASTProcessTrain(object):
def __init__(self,
image_shape = [512, 512],
background_ratio = 0.125,
min_crop_side_ratio = 0.1,
min_text_size = 10,
**kwargs):
self.input_size = image_shape[1]
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
self.background_ratio = background_ratio
self.min_crop_side_ratio = min_crop_side_ratio
self.min_text_size = min_text_size
def preprocess(self, im):
input_size = self.input_size
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
# im = im[:, :, ::-1].astype(np.float32)
im = im / 255
im -= img_mean
im /= img_std
new_h, new_w, _ = im.shape
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
im_padded[:new_h, :new_w, :] = im
im_padded = im_padded.transpose((2, 0, 1))
im_padded = im_padded[np.newaxis, :]
return im_padded, im_scale
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if 0.333 < rand_degree_ratio < 0.666:
rand_degree_cnt = 2
elif rand_degree_ratio > 0.666:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4):
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx)\
- math.sin(rot_angle) * (sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx)\
+ math.cos(rot_angle) * (sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
dst_polys = np.array(dst_polys, dtype=np.float32)
return dst_im, dst_polys
def polygon_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
def check_and_validate_polys(self, polys, tags, img_height, img_width):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
h, w = img_height, img_width
if polys.shape[0] == 0:
return polys
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
for poly, tag in zip(polys, tags):
p_area = self.polygon_area(poly)
#invalid poly
if abs(p_area) < 1:
continue
if p_area > 0:
#'poly in wrong direction'
if not tag:
tag = True #reversed cases should be ignore
poly = poly[(0, 3, 2, 1), :]
validated_polys.append(poly)
validated_tags.append(tag)
return np.array(validated_polys), np.array(validated_tags)
def draw_img_polys(self, img, polys):
if len(img.shape) == 4:
img = np.squeeze(img, axis=0)
if img.shape[0] == 3:
img = img.transpose((1, 2, 0))
img[:, :, 2] += 123.68
img[:, :, 1] += 116.78
img[:, :, 0] += 103.94
cv2.imwrite("tmp.jpg", img)
img = cv2.imread("tmp.jpg")
for box in polys:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
import random
ino = random.randint(0, 100)
cv2.imwrite("tmp_%d.jpg" % ino, img)
return
def shrink_poly(self, poly, r):
"""
fit a poly inside the origin poly, maybe bugs here...
used for generate the score map
:param poly: the text poly
:param r: r in the paper
:return: the shrinked poly
"""
# shrink ratio
R = 0.3
# find the longer pair
dist0 = np.linalg.norm(poly[0] - poly[1])
dist1 = np.linalg.norm(poly[2] - poly[3])
dist2 = np.linalg.norm(poly[0] - poly[3])
dist3 = np.linalg.norm(poly[1] - poly[2])
if dist0 + dist1 > dist2 + dist3:
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
## p0, p3
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
else:
## p0, p3
# print poly
theta = np.arctan2((poly[3][0] - poly[0][0]),
(poly[3][1] - poly[0][1]))
poly[0][0] += R * r[0] * np.sin(theta)
poly[0][1] += R * r[0] * np.cos(theta)
poly[3][0] -= R * r[3] * np.sin(theta)
poly[3][1] -= R * r[3] * np.cos(theta)
## p1, p2
theta = np.arctan2((poly[2][0] - poly[1][0]),
(poly[2][1] - poly[1][1]))
poly[1][0] += R * r[1] * np.sin(theta)
poly[1][1] += R * r[1] * np.cos(theta)
poly[2][0] -= R * r[2] * np.sin(theta)
poly[2][1] -= R * r[2] * np.cos(theta)
## p0, p1
theta = np.arctan2((poly[1][1] - poly[0][1]),
(poly[1][0] - poly[0][0]))
poly[0][0] += R * r[0] * np.cos(theta)
poly[0][1] += R * r[0] * np.sin(theta)
poly[1][0] -= R * r[1] * np.cos(theta)
poly[1][1] -= R * r[1] * np.sin(theta)
## p2, p3
theta = np.arctan2((poly[2][1] - poly[3][1]),
(poly[2][0] - poly[3][0]))
poly[3][0] += R * r[3] * np.cos(theta)
poly[3][1] += R * r[3] * np.sin(theta)
poly[2][0] -= R * r[2] * np.cos(theta)
poly[2][1] -= R * r[2] * np.sin(theta)
return poly
def generate_quad(self, im_size, polys, tags):
"""
Generate quadrangle.
"""
h, w = im_size
poly_mask = np.zeros((h, w), dtype=np.uint8)
score_map = np.zeros((h, w), dtype=np.uint8)
# (x1, y1, ..., x4, y4, short_edge_norm)
geo_map = np.zeros((h, w, 9), dtype=np.float32)
# mask used during traning, to ignore some hard areas
training_mask = np.ones((h, w), dtype=np.uint8)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
r = [None, None, None, None]
for i in range(4):
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
r[i] = min(dist1, dist2)
# score map
shrinked_poly = self.shrink_poly(
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
cv2.fillPoly(score_map, shrinked_poly, 1)
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
# if the poly is too small, then ignore it during training
poly_h = min(
np.linalg.norm(poly[0] - poly[3]),
np.linalg.norm(poly[1] - poly[2]))
poly_w = min(
np.linalg.norm(poly[0] - poly[1]),
np.linalg.norm(poly[2] - poly[3]))
if min(poly_h, poly_w) < self.min_text_size:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
if tag:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0)
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
# geo map.
y_in_poly = xy_in_poly[:, 0]
x_in_poly = xy_in_poly[:, 1]
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
for pno in range(4):
geo_channel_beg = pno * 2
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
x_in_poly - poly[pno, 0]
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
y_in_poly - poly[pno, 1]
geo_map[y_in_poly, x_in_poly, 8] = \
1.0 / max(min(poly_h, poly_w), 1.0)
return score_map, geo_map, training_mask
def crop_area(self,
im,
polys,
tags,
crop_background=False,
max_tries=50):
"""
make random crop from the input image
:param im:
:param polys:
:param tags:
:param crop_background:
:param max_tries:
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_side_ratio * w or \
ymax - ymin < self.min_crop_side_ratio * h:
# area too small
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
& (polys[:, :, 0] <= xmax)\
& (polys[:, :, 1] >= ymin)\
& (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = []
tags = []
return im, polys, tags
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags
return im, polys, tags
def crop_background_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=True)
if len(text_polys) > 0:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
score_map = np.zeros((input_size, input_size), dtype=np.float32)
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
training_mask = np.ones((input_size, input_size), dtype=np.float32)
return im, score_map, geo_map, training_mask
def crop_foreground_infor(self, im, text_polys, text_tags):
im, text_polys, text_tags = self.crop_area(
im, text_polys, text_tags, crop_background=False)
if text_polys.shape[0] == 0:
return None
#continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
# pad and resize image
input_size = self.input_size
im, ratio = self.preprocess(im)
text_polys[:, :, 0] *= ratio
text_polys[:, :, 1] *= ratio
_, _, new_h, new_w = im.shape
# print(im.shape)
# self.draw_img_polys(im, text_polys)
score_map, geo_map, training_mask = self.generate_quad(
(new_h, new_w), text_polys, text_tags)
return im, score_map, geo_map, training_mask
def __call__(self, data):
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
if im is None:
return None
if text_polys.shape[0] == 0:
return None
#add rotate cases
if np.random.rand() < 0.5:
im, text_polys = self.rotate_im_poly(im, text_polys)
h, w, _ = im.shape
text_polys, text_tags = self.check_and_validate_polys(text_polys,
text_tags, h, w)
if text_polys.shape[0] == 0:
return None
# random scale this image
rd_scale = np.random.choice(self.random_scale)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
if np.random.rand() < self.background_ratio:
outs = self.crop_background_infor(im, text_polys, text_tags)
else:
outs = self.crop_foreground_infor(im, text_polys, text_tags)
if outs is None:
return None
im, score_map, geo_map, training_mask = outs
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
geo_map = np.swapaxes(geo_map, 1, 2)
geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
training_mask = training_mask[np.newaxis, ::4, ::4]
training_mask = training_mask.astype(np.float32)
data['image'] = im[0]
data['score_map'] = score_map
data['geo_map'] = geo_map
data['training_mask'] = training_mask
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
return data
\ No newline at end of file
...@@ -52,6 +52,7 @@ class DetLabelEncode(object): ...@@ -52,6 +52,7 @@ class DetLabelEncode(object):
txt_tags.append(True) txt_tags.append(True)
else: else:
txt_tags.append(False) txt_tags.append(False)
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32) boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool) txt_tags = np.array(txt_tags, dtype=np.bool)
...@@ -70,6 +71,17 @@ class DetLabelEncode(object): ...@@ -70,6 +71,17 @@ class DetLabelEncode(object):
rect[3] = pts[np.argmax(diff)] rect[3] = pts[np.argmax(diff)]
return rect return rect
def expand_points_num(self, boxes):
max_points_num = 0
for box in boxes:
if len(box) > max_points_num:
max_points_num = len(box)
ex_boxes = []
for box in boxes:
ex_box = box + [box[-1]] * (max_points_num - len(box))
ex_boxes.append(ex_box)
return ex_boxes
class BaseRecLabelEncode(object): class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
...@@ -79,15 +91,17 @@ class BaseRecLabelEncode(object): ...@@ -79,15 +91,17 @@ class BaseRecLabelEncode(object):
character_dict_path=None, character_dict_path=None,
character_type='ch', character_type='ch',
use_space_char=False): use_space_char=False):
support_character_type = ['ch', 'en', 'en_sensitive'] support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type == "ch": elif character_type in ["ch", "french", "german", "japan", "korean"]:
self.character_str = "" self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
......
...@@ -120,26 +120,37 @@ class DetResizeForTest(object): ...@@ -120,26 +120,37 @@ class DetResizeForTest(object):
if 'limit_side_len' in kwargs: if 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len'] self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min') self.limit_type = kwargs.get('limit_type', 'min')
if 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else: else:
self.limit_side_len = 736 self.limit_side_len = 736
self.limit_type = 'min' self.limit_type = 'min'
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0: if self.resize_type == 0:
img, shape = self.resize_image_type0(img) # img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else: else:
img, shape = self.resize_image_type1(img) # img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img data['image'] = img
data['shape'] = shape data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data return data
def resize_image_type1(self, img): def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c) ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h))) img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, np.array([ori_h, ori_w]) # return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img): def resize_image_type0(self, img):
""" """
...@@ -182,4 +193,31 @@ class DetResizeForTest(object): ...@@ -182,4 +193,31 @@ class DetResizeForTest(object):
except: except:
print(img.shape, resize_w, resize_h) print(img.shape, resize_w, resize_h)
sys.exit(0) sys.exit(0)
return img, np.array([h, w]) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import cv2
import numpy as np
import json
import sys
import os
__all__ = ['SASTProcessTrain']
class SASTProcessTrain(object):
def __init__(self,
image_shape = [512, 512],
min_crop_size = 24,
min_crop_side_ratio = 0.3,
min_text_size = 10,
max_text_size = 512,
**kwargs):
self.input_size = image_shape[1]
self.min_crop_size = min_crop_size
self.min_crop_side_ratio = min_crop_side_ratio
self.min_text_size = min_text_size
self.max_text_size = max_text_size
def quad_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
]
return np.sum(edge) / 2.
def gen_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True:
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
:param polys:
:param tags:
:return:
"""
(h, w) = xxx_todo_changeme
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
hv_tags = []
for poly, tag in zip(polys, tags):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
print('invalid poly')
continue
if p_area > 0:
if tag == False:
print('poly in wrong direction')
tag = True # reversed cases should be ignore
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
hv_tag = 1
if len_w * 2.0 < len_h:
hv_tag = 0
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
"""
make random crop from the input image
:param im:
:param polys:
:param tags:
:param crop_background:
:param max_tries: 50 -> 25
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w: maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h: maxy + pad_h] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags, hv_tags
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
# ymax - ymin < ARGS.min_crop_side_ratio * h:
if xmax - xmin < self.min_crop_size or \
ymax - ymin < self.min_crop_size:
# area too small
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
else:
continue
im = im[ymin: ymax + 1, xmin: xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags, hv_tags
return im, polys, tags, hv_tags
def generate_direction_map(self, poly_quads, direction_map):
"""
"""
width_list = []
height_list = []
for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
for quad in poly_quads:
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
return direction_map
def calculate_average_height(self, poly_quads):
"""
"""
height_list = []
for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
def generate_tcl_label(self, hw, polys, tags, ds_ratio,
tcl_ratio=0.3, shrink_ratio_of_width=0.15):
"""
Generate polygon.
"""
h, w = hw
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros((h, w,), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones((h, w,), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
continue
if tag:
# continue
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly)
# stcl map
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
# generate tcl map
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
# generate tbo map
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
return score_map, tbo_map, training_mask
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
"""
Generate tcl map, tvo map and tbo map.
"""
h, w = hw
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
poly_mask = np.zeros((h, w), dtype=np.float32)
tvo_map = np.ones((9, h, w), dtype=np.float32)
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
# tco map
tco_map = np.ones((3, h, w), dtype=np.float32)
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
for poly, poly_tag in zip(polys, tags):
if poly_tag == True:
continue
# adjust point order for vertical poly
poly = self.adjust_point(poly)
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
# generate tcl map and text, 128 * 128
tcl_poly = self.poly2tcl(poly, tcl_ratio)
# generate poly_tv_xy_map
for idx in range(4):
cv2.fillPoly(poly_tv_xy_map[2 * idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(min(max(min_area_quad[idx, 0], 0), w)))
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(min(max(min_area_quad[idx, 1], 0), h)))
# generate poly_tc_xy_map
for idx in range(2):
cv2.fillPoly(poly_tc_xy_map[idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))
# generate poly_short_edge_map
cv2.fillPoly(poly_short_edge_map,
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
# generate poly_mask and training_mask
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)
tvo_map *= poly_mask
tvo_map[:8] -= poly_tv_xy_map
tvo_map[-1] /= poly_short_edge_map
tvo_map = tvo_map.transpose((1, 2, 0))
tco_map *= poly_mask
tco_map[:2] -= poly_tc_xy_map
tco_map[-1] /= poly_short_edge_map
tco_map = tco_map.transpose((1, 2, 0))
return tvo_map, tco_map
def adjust_point(self, poly):
"""
adjust point order.
"""
point_num = poly.shape[0]
if point_num == 4:
len_1 = np.linalg.norm(poly[0] - poly[1])
len_2 = np.linalg.norm(poly[1] - poly[2])
len_3 = np.linalg.norm(poly[2] - poly[3])
len_4 = np.linalg.norm(poly[3] - poly[0])
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
poly = poly[[1, 2, 3, 0], :]
elif point_num > 4:
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
index = list(range(1, point_num)) + [0]
poly = poly[np.array(index), :]
return poly
def gen_min_area_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if point_num == 4:
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad, center_point
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
"""
shrink poly with given length.
"""
upper_edge_list = []
def get_cut_info(edge_len_list, cut_len):
for idx, edge_len in enumerate(edge_len_list):
cut_len -= edge_len
if cut_len <= 0.000001:
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
return idx, ratio
for quad in quads:
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
out_quad_list.append(quads[idx])
out_quad_list.append(right_quad)
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
def vector_angle(self, A, B):
"""
Calculate the angle between vector AB and x-axis positive direction.
"""
AB = np.array([B[1] - A[1], B[0] - A[0]])
return np.arctan2(*AB)
def theta_line_cross_point(self, theta, point):
"""
Calculate the line through given point and angle in ax + by + c =0 form.
"""
x, y = point
cos = np.cos(theta)
sin = np.sin(theta)
return [sin, -cos, cos * y - sin * x]
def line_cross_two_point(self, A, B):
"""
Calculate the line through given point A and B in ax + by + c =0 form.
"""
angle = self.vector_angle(A, B)
return self.theta_line_cross_point(angle, A)
def average_angle(self, poly):
"""
Calculate the average angle between left and right edge in given poly.
"""
p0, p1, p2, p3 = poly
angle30 = self.vector_angle(p3, p0)
angle21 = self.vector_angle(p2, p1)
return (angle30 + angle21) / 2
def line_cross_point(self, line1, line2):
"""
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
"""
a1, b1, c1 = line1
a2, b2, c2 = line2
d = a1 * b2 - a2 * b1
if d == 0:
#print("line1", line1)
#print("line2", line2)
print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
y = (a2 * c1 - a1 * c2) / d
return np.array([x, y], dtype=np.float32)
def quad2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point. (4, 2)
"""
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
def poly2tcl(self, poly, ratio):
"""
Generate center line by poly clock-wise point.
"""
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
"""
Generate tbo_map for give quad.
"""
# upper and lower line function: ax + by + c = 0;
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line.
angle = self.average_angle(quad)
xy_in_poly = np.argwhere(tcl_mask == 1)
for y, x in xy_in_poly:
point = (x, y)
line = self.theta_line_cross_point(angle, point)
cross_point_upper = self.line_cross_point(up_line, line)
cross_point_lower = self.line_cross_point(lower_line, line)
##FIX, offset reverse
upper_offset_x, upper_offset_y = cross_point_upper - point
lower_offset_x, lower_offset_y = cross_point_lower - point
tbo_map[y, x, 0] = upper_offset_y
tbo_map[y, x, 1] = upper_offset_x
tbo_map[y, x, 2] = lower_offset_y
tbo_map[y, x, 3] = lower_offset_x
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
return tbo_map
def poly2quads(self, poly):
"""
Split poly into quads.
"""
quad_list = []
point_num = poly.shape[0]
# point pair
point_pair_list = []
for idx in range(point_num // 2):
point_pair = [poly[idx], poly[point_num - 1 - idx]]
point_pair_list.append(point_pair)
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list)
def __call__(self, data):
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
if im is None:
return None
if text_polys.shape[0] == 0:
return None
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
if text_polys.shape[0] == 0:
return None
#set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale
asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
#no background
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
text_polys, text_tags, hv_tags, crop_background=False)
if text_polys.shape[0] == 0:
return None
#continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
#resize image
std_ratio = float(self.input_size) / max(new_w, new_h)
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
#add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks/2)*2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
#add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
#add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape
if min(new_w, new_h) < self.input_size * 0.5:
return None
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
# Random the start position
del_h = self.input_size - new_h
del_w = self.input_size - new_w
sh, sw = 0, 0
if del_h > 1:
sh = int(np.random.rand() * del_h)
if del_w > 1:
sw = int(np.random.rand() * del_w)
# Padding
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
text_polys, text_tags, 0.25)
# SAST head
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
# print("test--------tvo_map shape:", tvo_map.shape)
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1))
data['image'] = im_padded[::-1, :, :]
data['score_map'] = score_map[np.newaxis, :, :]
data['border_map'] = border_map.transpose((2, 0, 1))
data['training_mask'] = training_mask[np.newaxis, :, :]
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
data['tco_map'] = tco_map.transpose((2, 0, 1))
return data
\ No newline at end of file
...@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset): ...@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset):
self.delimiter = dataset_config.get('delimiter', '\t') self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list) data_source_num = len(label_file_list)
if data_source_num == 1: ratio_list = dataset_config.get("ratio_list", [1.0])
ratio_list = [1.0] if isinstance(ratio_list, (float, int)):
else: ratio_list = [float(ratio_list)] * len(data_source_num)
ratio_list = dataset_config.pop('ratio_list')
assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
assert len( assert len(
ratio_list ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list." ) == data_source_num, "The length of ratio_list should be the same as the file_list."
...@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset): ...@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset):
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines_list, data_num_list = self.get_image_info_list( self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
label_file_list) self.data_idx_order_list = list(range(len(self.data_lines)))
self.data_idx_order_list = self.dataset_traversal( if mode.lower() == "train":
data_num_list, ratio_list, batch_size) self.shuffle_data_random()
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
def get_image_info_list(self, file_list): def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
data_lines_list = [] data_lines = []
data_num_list = [] for idx, file in enumerate(file_list):
for file in file_list:
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
data_lines_list.append(lines) lines = random.sample(lines,
data_num_list.append(len(lines)) round(len(lines) * ratio_list[idx]))
return data_lines_list, data_num_list data_lines.extend(lines)
return data_lines
def dataset_traversal(self, data_num_list, ratio_list, batch_size):
select_num_list = []
dataset_num = len(data_num_list)
for dno in range(dataset_num):
select_num = round(batch_size * ratio_list[dno])
select_num = max(select_num, 1)
select_num_list.append(select_num)
data_idx_order_list = []
cur_index_sets = [0] * dataset_num
while True:
finish_read_num = 0
for dataset_idx in range(dataset_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
finish_read_num += 1
else:
select_num = select_num_list[dataset_idx]
for sno in range(select_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
break
data_idx_order_list.append((dataset_idx, cur_index))
cur_index_sets[dataset_idx] += 1
if finish_read_num == dataset_num:
break
return data_idx_order_list
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
for dno in range(len(self.data_lines_list)): random.shuffle(self.data_lines)
random.shuffle(self.data_lines_list[dno])
return return
def __getitem__(self, idx): def __getitem__(self, idx):
dataset_idx, file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines_list[dataset_idx][file_idx] data_line = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
......
...@@ -18,6 +18,8 @@ import copy ...@@ -18,6 +18,8 @@ import copy
def build_loss(config): def build_loss(config):
# det loss # det loss
from .det_db_loss import DBLoss from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
...@@ -25,7 +27,7 @@ def build_loss(config): ...@@ -25,7 +27,7 @@ def build_loss(config):
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss'] support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
class EASTLoss(nn.Layer):
"""
"""
def __init__(self,
eps=1e-6,
**kwargs):
super(EASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
l_score, l_geo, l_mask = labels[1:]
f_score = predicts['f_score']
f_geo = predicts['f_geo']
dice_loss = self.dice_loss(f_score, l_score, l_mask)
#smoooth_l1_loss
channels = 8
l_geo_split = paddle.split(
l_geo, num_or_sections=channels + 1, axis=1)
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
smooth_l1 = 0
for i in range(0, channels):
geo_diff = l_geo_split[i] - f_geo_split[i]
abs_geo_diff = paddle.abs(geo_diff)
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
out_loss = l_geo_split[-1] / channels * in_loss * l_score
smooth_l1 += out_loss
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
dice_loss = dice_loss * 0.01
total_loss = dice_loss + smooth_l1_loss
losses = {"loss":total_loss, \
"dice_loss":dice_loss,\
"smooth_l1_loss":smooth_l1_loss}
return losses
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
import paddle.fluid as fluid
import numpy as np
class SASTLoss(nn.Layer):
"""
"""
def __init__(self,
eps=1e-6,
**kwargs):
super(SASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
"""
tcl_pos: N x 128 x 3
tcl_mask: N x 128 x 1
tcl_label: N x X list or LoDTensor
"""
f_score = predicts['f_score']
f_border = predicts['f_border']
f_tvo = predicts['f_tvo']
f_tco = predicts['f_tco']
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
#score_loss
intersection = paddle.sum(f_score * l_score * l_mask)
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
#border loss
l_border_split, l_border_norm = paddle.split(l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape)
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
#tvo_loss
l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1)
f_tvo_split = f_tvo
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
#
tvo_geo_diff = l_tvo_split - f_tvo_split
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
tvo_sign = abs_tvo_geo_diff < 1.0
tvo_sign = paddle.cast(tvo_sign, dtype='float32')
tvo_sign.stop_gradient = True
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
(paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
#tco_loss
l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1)
f_tco_split = f_tco
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
tco_geo_diff = l_tco_split - f_tco_split
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
tco_sign = abs_tco_geo_diff < 1.0
tco_sign = paddle.cast(tco_sign, dtype='float32')
tco_sign.stop_gradient = True
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
tco_out_loss = l_tco_norm_split * tco_in_loss
tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
(paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
# total loss
tvo_lw, tco_lw = 1.5, 1.5
score_lw, border_lw = 1.0, 1.0
total_loss = score_loss * score_lw + border_loss * border_lw + \
tvo_loss * tvo_lw + tco_loss * tco_lw
losses = {'loss':total_loss, "score_loss":score_loss,\
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
return losses
\ No newline at end of file
...@@ -19,6 +19,7 @@ def build_backbone(config, model_type): ...@@ -19,6 +19,7 @@ def build_backbone(config, model_type):
if model_type == 'det': if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST'] support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec' or model_type == 'cls': elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet_SAST"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet_SAST(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet_SAST, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
# num_channels = [64, 256, 512,
# 1024] if layers >= 50 else [64, 64, 128, 256]
# num_filters = [64, 128, 256, 512]
num_channels = [64, 256, 512,
1024, 2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
\ No newline at end of file
...@@ -18,13 +18,15 @@ __all__ = ['build_head'] ...@@ -18,13 +18,15 @@ __all__ = ['build_head']
def build_head(config): def build_head(config):
# det head # det head
from .det_db_head import DBHead from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
# rec head # rec head
from .rec_ctc_head import CTCHead from .rec_ctc_head import CTCHead
# cls head # cls head
from .cls_head import ClsHead from .cls_head import ClsHead
support_dict = ['DBHead', 'CTCHead', 'ClsHead'] support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format( assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class EASTHead(nn.Layer):
"""
"""
def __init__(self, in_channels, model_name, **kwargs):
super(EASTHead, self).__init__()
self.model_name = model_name
if self.model_name == "large":
num_outputs = [128, 64, 1, 8]
else:
num_outputs = [64, 32, 1, 8]
self.det_conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=num_outputs[0],
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="det_head1")
self.det_conv2 = ConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="det_head2")
self.score_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name="f_score")
self.geo_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name="f_geo")
def forward(self, x):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
f_score = F.sigmoid(f_score)
f_geo = self.geo_conv(f_det)
f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800
pred = {'f_score': f_score, 'f_geo': f_geo}
return pred
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class SAST_Header1(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(SAST_Header1, self).__init__()
out_channels = [64, 64, 128]
self.score_conv = nn.Sequential(
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_score1'),
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_score2'),
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_score3'),
ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name='f_score4')
)
self.border_conv = nn.Sequential(
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_border1'),
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_border2'),
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_border3'),
ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name='f_border4')
)
def forward(self, x):
f_score = self.score_conv(x)
f_score = F.sigmoid(f_score)
f_border = self.border_conv(x)
return f_score, f_border
class SAST_Header2(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(SAST_Header2, self).__init__()
out_channels = [64, 64, 128]
self.tvo_conv = nn.Sequential(
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tvo1'),
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tvo2'),
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tvo3'),
ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name='f_tvo4')
)
self.tco_conv = nn.Sequential(
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tco1'),
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tco2'),
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tco3'),
ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name='f_tco4')
)
def forward(self, x):
f_tvo = self.tvo_conv(x)
f_tco = self.tco_conv(x)
return f_tvo, f_tco
class SASTHead(nn.Layer):
"""
"""
def __init__(self, in_channels, **kwargs):
super(SASTHead, self).__init__()
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
def forward(self, x):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
predicts = {}
predicts['f_score'] = f_score
predicts['f_border'] = f_border
predicts['f_tvo'] = f_tvo
predicts['f_tco'] = f_tco
return predicts
\ No newline at end of file
...@@ -16,8 +16,10 @@ __all__ = ['build_neck'] ...@@ -16,8 +16,10 @@ __all__ = ['build_neck']
def build_neck(config): def build_neck(config):
from .db_fpn import DBFPN from .db_fpn import DBFPN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
support_dict = ['DBFPN', 'SequenceEncoder'] support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class DeConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
return x
class EASTFPN(nn.Layer):
def __init__(self, in_channels, model_name, **kwargs):
super(EASTFPN, self).__init__()
self.model_name = model_name
if self.model_name == "large":
self.out_channels = 128
else:
self.out_channels = 64
self.in_channels = in_channels[::-1]
self.h1_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[1],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_1")
self.h2_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[2],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_2")
self.h3_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[3],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_3")
self.g0_deconv = DeConvBNLayer(
in_channels=self.in_channels[0],
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_0")
self.g1_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_1")
self.g2_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_2")
self.g3_conv = ConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_g_3")
def forward(self, x):
f = x[::-1]
h = f[0]
g = self.g0_deconv(h)
h = paddle.concat([g, f[1]], axis=1)
h = self.h1_conv(h)
g = self.g1_deconv(h)
h = paddle.concat([g, f[2]], axis=1)
h = self.h2_conv(h)
g = self.g2_deconv(h)
h = paddle.concat([g, f[3]], axis=1)
h = self.h3_conv(h)
g = self.g3_conv(h)
return g
\ No newline at end of file
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class DeConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
return x
class FPN_Up_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 1, 1, act=None, name='fpn_up_h0')
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 1, 1, act=None, name='fpn_up_h1')
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 1, 1, act=None, name='fpn_up_h2')
self.h3_conv = ConvBNLayer(in_channels[3], out_channels[3], 1, 1, act=None, name='fpn_up_h3')
self.h4_conv = ConvBNLayer(in_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_h4')
self.g0_conv = DeConvBNLayer(out_channels[0], out_channels[1], 4, 2, act=None, name='fpn_up_g0')
self.g1_conv = nn.Sequential(
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_up_g1_1'),
DeConvBNLayer(out_channels[1], out_channels[2], 4, 2, act=None, name='fpn_up_g1_2')
)
self.g2_conv = nn.Sequential(
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_up_g2_1'),
DeConvBNLayer(out_channels[2], out_channels[3], 4, 2, act=None, name='fpn_up_g2_2')
)
self.g3_conv = nn.Sequential(
ConvBNLayer(out_channels[3], out_channels[3], 3, 1, act='relu', name='fpn_up_g3_1'),
DeConvBNLayer(out_channels[3], out_channels[4], 4, 2, act=None, name='fpn_up_g3_2')
)
self.g4_conv = nn.Sequential(
ConvBNLayer(out_channels[4], out_channels[4], 3, 1, act='relu', name='fpn_up_fusion_1'),
ConvBNLayer(out_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_fusion_2')
)
def _add_relu(self, x1, x2):
x = paddle.add(x=x1, y=x2)
x = F.relu(x)
return x
def forward(self, x):
f = x[2:][::-1]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
h3 = self.h3_conv(f[3])
h4 = self.h4_conv(f[4])
g0 = self.g0_conv(h0)
g1 = self._add_relu(g0, h1)
g1 = self.g1_conv(g1)
g2 = self.g2_conv(self._add_relu(g1, h2))
g3 = self.g3_conv(self._add_relu(g2, h3))
g4 = self.g4_conv(self._add_relu(g3, h4))
return g4
class FPN_Down_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 3, 1, act=None, name='fpn_down_h0')
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 3, 1, act=None, name='fpn_down_h1')
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 3, 1, act=None, name='fpn_down_h2')
self.g0_conv = ConvBNLayer(out_channels[0], out_channels[1], 3, 2, act=None, name='fpn_down_g0')
self.g1_conv = nn.Sequential(
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_down_g1_1'),
ConvBNLayer(out_channels[1], out_channels[2], 3, 2, act=None, name='fpn_down_g1_2')
)
self.g2_conv = nn.Sequential(
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_down_fusion_1'),
ConvBNLayer(out_channels[2], out_channels[2], 1, 1, act=None, name='fpn_down_fusion_2')
)
def forward(self, x):
f = x[:3]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
g0 = self.g0_conv(h0)
g1 = paddle.add(x=g0, y=h1)
g1 = F.relu(g1)
g1 = self.g1_conv(g1)
g2 = paddle.add(x=g1, y=h2)
g2 = F.relu(g2)
g2 = self.g2_conv(g2)
return g2
class Cross_Attention(nn.Layer):
def __init__(self, in_channels):
super(Cross_Attention, self).__init__()
self.theta_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_theta')
self.phi_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_phi')
self.g_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_g')
self.fh_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_weight')
self.fh_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_sc')
self.fv_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_weight')
self.fv_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_sc')
self.f_attn_conv = ConvBNLayer(in_channels * 2, in_channels, 1, 1, act='relu', name='f_attn')
def _cal_fweight(self, f, shape):
f_theta, f_phi, f_g = f
#flatten
f_theta = paddle.transpose(f_theta, [0, 2, 3, 1])
f_theta = paddle.reshape(f_theta, [shape[0] * shape[1], shape[2], 128])
f_phi = paddle.transpose(f_phi, [0, 2, 3, 1])
f_phi = paddle.reshape(f_phi, [shape[0] * shape[1], shape[2], 128])
f_g = paddle.transpose(f_g, [0, 2, 3, 1])
f_g = paddle.reshape(f_g, [shape[0] * shape[1], shape[2], 128])
#correlation
f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))
#scale
f_attn = f_attn / (128**0.5)
f_attn = F.softmax(f_attn)
#weighted sum
f_weight = paddle.matmul(f_attn, f_g)
f_weight = paddle.reshape(
f_weight, [shape[0], shape[1], shape[2], 128])
return f_weight
def forward(self, f_common):
f_shape = paddle.shape(f_common)
# print('f_shape: ', f_shape)
f_theta = self.theta_conv(f_common)
f_phi = self.phi_conv(f_common)
f_g = self.g_conv(f_common)
######## horizon ########
fh_weight = self._cal_fweight([f_theta, f_phi, f_g],
[f_shape[0], f_shape[2], f_shape[3]])
fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2])
fh_weight = self.fh_weight_conv(fh_weight)
#short cut
fh_sc = self.fh_sc_conv(f_common)
f_h = F.relu(fh_weight + fh_sc)
######## vertical ########
fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2])
fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2])
fv_g = paddle.transpose(f_g, [0, 1, 3, 2])
fv_weight = self._cal_fweight([fv_theta, fv_phi, fv_g],
[f_shape[0], f_shape[3], f_shape[2]])
fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1])
fv_weight = self.fv_weight_conv(fv_weight)
#short cut
fv_sc = self.fv_sc_conv(f_common)
f_v = F.relu(fv_weight + fv_sc)
######## merge ########
f_attn = paddle.concat([f_h, f_v], axis=1)
f_attn = self.f_attn_conv(f_attn)
return f_attn
class SASTFPN(nn.Layer):
def __init__(self, in_channels, with_cab=False, **kwargs):
super(SASTFPN, self).__init__()
self.in_channels = in_channels
self.with_cab = with_cab
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
self.out_channels = 128
self.cross_attention = Cross_Attention(self.out_channels)
def forward(self, x):
#down fpn
f_down = self.FPN_Down_Fusion(x)
#up fpn
f_up = self.FPN_Up_Fusion(x)
#fusion
f_common = paddle.add(x=f_down, y=f_up)
f_common = F.relu(f_common)
if self.with_cab:
# print('enhence f_common with CAB.')
f_common = self.cross_attention(f_common)
return f_common
...@@ -24,11 +24,13 @@ __all__ = ['build_post_process'] ...@@ -24,11 +24,13 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
support_dict = [ support_dict = [
'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment