Unverified Commit 4ffb5b62 authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #924 from WenmuZhou/dygraph

Dygraph
parents bc93c549 aad3093a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
__all__ = ['PaddleOCR']
model_params = {
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
'rec':
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
}
SUPPORT_DET_MODEL = ['DB']
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
if not os.path.exists(os.path.join(
model_storage_directory, 'model')) or not os.path.exists(
os.path.join(model_storage_directory, 'params')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
if "model" in member.name:
filename = 'model'
elif "params" in member.name:
filename = 'params'
else:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args():
import argparse
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_max_side_len", type=float, default=960)
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
"--rec_char_dict_path",
type=str,
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
return parser.parse_args()
class PaddleOCR(predict_system.TextSystem):
def __init__(self, **kwargs):
"""
paddleocr package
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args()
postprocess_params.__dict__.update(**kwargs)
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir, model_params['det'])
maybe_download(postprocess_params.rec_model_dir, model_params['rec'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
postprocess_params.rec_char_dict_path = Path(
__file__).parent / postprocess_params.rec_char_dict_path
# init det_model and rec_model
super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True):
"""
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
det: use text detection or not, if false, only rec will be exec. default is True
rec: use text recognition or not, if false, only det will be exec. default is True
"""
assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, str):
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if det and rec:
dt_boxes, rec_res = self.__call__(img)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec:
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None:
return None
return [box.tolist() for box in dt_boxes]
else:
if not isinstance(img, list):
img = [img]
rec_res, elapse = self.text_recognizer(img)
return rec_res
def main():
# for com
args = parse_args()
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
ocr_engine = PaddleOCR()
for img_path in image_file_list:
print(img_path)
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
for line in result:
print(line)
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
__all__ = ['build_dataloader', 'transform', 'create_operators']
def build_dataset(config, global_config):
from ppocr.data.dataset import SimpleDataSet, LMDBDateSet
support_dict = ['SimpleDataSet', 'LMDBDateSet']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
dataset = eval(module_name)(config, global_config)
return dataset
def build_dataloader(config, device, distributed=False, global_config=None):
from ppocr.data.dataset import BatchBalancedDataLoader
config = copy.deepcopy(config)
dataset_config = config['dataset']
_dataset_list = []
file_list = dataset_config.pop('file_list')
if len(file_list) == 1:
ratio_list = [1.0]
else:
ratio_list = dataset_config.pop('ratio_list')
for file in file_list:
dataset_config['file_list'] = file
_dataset = build_dataset(dataset_config, global_config)
_dataset_list.append(_dataset)
data_loader = BatchBalancedDataLoader(_dataset_list, ratio_list,
distributed, device, config['loader'])
return data_loader, _dataset.info_dict
def test_loader():
import time
from tools.program import load_config, ArgsParser
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
place = paddle.CPUPlace()
paddle.disable_static(place)
import time
data_loader, _ = build_dataloader(
config['TRAIN'], place, global_config=config['Global'])
start = time.time()
print(len(data_loader))
for epoch in range(1):
print('epoch {} ****************'.format(epoch))
for i, batch in enumerate(data_loader):
if i > len(data_loader):
break
t = time.time() - start
start = time.time()
print('{}, batch : {} ,time {}'.format(i, len(batch[0]), t))
continue
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import cv2
fig = plt.figure()
# # cv2.imwrite('img.jpg',batch[0].numpy()[0].transpose((1,2,0)))
# # cv2.imwrite('bmap.jpg',batch[1].numpy()[0])
# # cv2.imwrite('bmask.jpg',batch[2].numpy()[0])
# # cv2.imwrite('smap.jpg',batch[3].numpy()[0])
# # cv2.imwrite('smask.jpg',batch[4].numpy()[0])
plt.title('img')
plt.imshow(batch[0].numpy()[0].transpose((1, 2, 0)))
# plt.figure()
# plt.title('bmap')
# plt.imshow(batch[1].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('bmask')
# plt.imshow(batch[2].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smap')
# plt.imshow(batch[3].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smask')
# plt.imshow(batch[4].numpy()[0],cmap='Greys')
# plt.show()
# break
if __name__ == '__main__':
test_loader()
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import lmdb
import random
import signal
import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler, BatchSampler
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
class ModeException(Exception):
"""
ModeException
"""
def __init__(self, message='', mode=''):
message += "\nOnly the following 3 modes are supported: " \
"train, valid, test. Given mode is {}".format(mode)
super(ModeException, self).__init__(message)
class SampleNumException(Exception):
"""
SampleNumException
"""
def __init__(self, message='', sample_num=0, batch_size=1):
message += "\nError: The number of the whole data ({}) " \
"is smaller than the batch_size ({}), and drop_last " \
"is turnning on, so nothing will feed in program, " \
"Terminated now. Please reset batch_size to a smaller " \
"number or feed more data!".format(sample_num, batch_size)
super(SampleNumException, self).__init__(message)
def get_file_list(file_list, data_dir, delimiter='\t'):
"""
read label list from file and shuffle the list
Args:
params(dict):
"""
if isinstance(file_list, str):
file_list = [file_list]
data_source_list = []
for file in file_list:
with open(file) as f:
full_lines = [line.strip() for line in f]
for line in full_lines:
try:
img_path, label = line.split(delimiter)
except:
logger = get_logger()
logger.warning('label error in {}'.format(line))
img_path = os.path.join(data_dir, img_path)
data = {'img_path': img_path, 'label': label}
data_source_list.append(data)
return data_source_list
class LMDBDateSet(Dataset):
def __init__(self, config, global_config):
super(LMDBDateSet, self).__init__()
self.data_list = self.load_lmdb_dataset(
config['file_list'], global_config['max_text_length'])
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def load_lmdb_dataset(self, data_dir, max_text_length):
self.env = lmdb.open(
data_dir,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
if not self.env:
print('cannot create lmdb from %s' % (data_dir))
exit(0)
filtered_index_list = []
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples
for index in range(self.nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_text_length:
# print(f'The length of the label is longer than max_length: length
# {len(label)}, {label} in dataset {self.root}')
continue
# By default, images containing characters which are not in opt.character are filtered.
# You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
filtered_index_list.append(index)
return filtered_index_list
def print_lmdb_sets_info(self, lmdb_sets):
lmdb_info_strs = []
for dataset_idx in range(len(lmdb_sets)):
tmp_str = " %s:%d," % (lmdb_sets[dataset_idx]['dirpath'],
lmdb_sets[dataset_idx]['num_samples'])
lmdb_info_strs.append(tmp_str)
lmdb_info_strs = ''.join(lmdb_info_strs)
logger = get_logger()
logger.info("DataSummary:" + lmdb_info_strs)
return
def __getitem__(self, idx):
idx = self.data_list[idx]
with self.env.begin(write=False) as txn:
label_key = 'label-%09d'.encode() % idx
label = txn.get(label_key)
if label is not None:
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % idx
imgbuf = txn.get(img_key)
data = {'image': imgbuf, 'label': label}
outs = transform(data, self.ops)
else:
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class SimpleDataSet(Dataset):
def __init__(self, config, global_config):
super(SimpleDataSet, self).__init__()
delimiter = config.get('delimiter', '\t')
self.data_list = get_file_list(config['file_list'], config['data_dir'],
delimiter)
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def __getitem__(self, idx):
data = copy.deepcopy(self.data_list[idx])
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class BatchBalancedDataLoader(object):
def __init__(self,
dataset_list: list,
ratio_list: list,
distributed,
device,
loader_args: dict):
"""
对datasetlist里的dataset按照ratio_list里对应的比例组合,似的每个batch里的数据按按照比例采样的
:param dataset_list: 数据集列表
:param ratio_list: 比例列表
:param loader_args: dataloader的配置
"""
assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
self.dataset_len = 0
self.data_loader_list = []
self.dataloader_iter_list = []
all_batch_size = loader_args.pop('batch_size')
batch_size_list = list(
map(int, [max(1.0, all_batch_size * x) for x in ratio_list]))
remain_num = all_batch_size - sum(batch_size_list)
batch_size_list[np.argmax(ratio_list)] += remain_num
for _dataset, _batch_size in zip(dataset_list, batch_size_list):
if distributed:
batch_sampler_class = DistributedBatchSampler
else:
batch_sampler_class = BatchSampler
batch_sampler = batch_sampler_class(
dataset=_dataset,
batch_size=_batch_size,
shuffle=loader_args['shuffle'],
drop_last=loader_args['drop_last'], )
_data_loader = DataLoader(
dataset=_dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=loader_args['num_workers'],
return_list=True, )
self.data_loader_list.append(_data_loader)
self.dataloader_iter_list.append(iter(_data_loader))
self.dataset_len += len(_dataset)
def __iter__(self):
return self
def __len__(self):
return min([len(x) for x in self.data_loader_list])
def __next__(self):
batch = []
for i, data_loader_iter in enumerate(self.dataloader_iter_list):
try:
_batch_i = next(data_loader_iter)
batch.append(_batch_i)
except StopIteration:
self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
_batch_i = next(self.dataloader_iter_list[i])
batch.append(_batch_i)
except ValueError:
pass
if len(batch) > 0:
batch_list = []
batch_item_size = len(batch[0])
for i in range(batch_item_size):
cur_item_list = [batch_i[i] for batch_i in batch]
batch_list.append(paddle.concat(cur_item_list, axis=0))
else:
batch_list = batch[0]
return batch_list
def fill_batch(batch):
"""
2020.09.08: The current paddle version only supports returning data with the same length.
Therefore, fill in the batches with inconsistent lengths.
this method is currently only useful for text detection
"""
keys = list(range(len(batch[0])))
v_max_len_dict = {}
for k in keys:
v_max_len_dict[k] = max([len(item[k]) for item in batch])
for item in batch:
length = []
for k in keys:
v = item[k]
length.append(len(v))
assert isinstance(v, np.ndarray)
if len(v) == v_max_len_dict[k]:
continue
try:
tmp_shape = [v_max_len_dict[k] - len(v)] + list(v[0].shape)
except:
a = 1
tmp_array = np.zeros(tmp_shape, dtype=v[0].dtype)
new_array = np.concatenate([v, tmp_array])
item[k] = new_array
item.append(length)
return batch
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from .iaa_augment import IaaAugment
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg
from .operators import *
from .label_ops import *
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import imgaug
import imgaug.augmenters as iaa
class AugmenterBuilder(object):
def __init__(self):
pass
def build(self, args, root=True):
if args is None or len(args) == 0:
return None
elif isinstance(args, list):
if root:
sequence = [self.build(value, root=False) for value in args]
return iaa.Sequential(sequence)
else:
return getattr(iaa, args[0])(
*[self.to_tuple_if_list(a) for a in args[1:]])
elif isinstance(args, dict):
cls = getattr(iaa, args['type'])
return cls(**{
k: self.to_tuple_if_list(v)
for k, v in args['args'].items()
})
else:
raise RuntimeError('unknown augmenter arg: ' + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
return tuple(obj)
return obj
class IaaAugment():
def __init__(self, augmenter_args=None, **kwargs):
if augmenter_args is None:
augmenter_args = [{
'type': 'Fliplr',
'args': {
'p': 0.5
}
}, {
'type': 'Affine',
'args': {
'rotate': [-10, 10]
}
}, {
'type': 'Resize',
'args': {
'size': [0.5, 3]
}
}]
self.augmenter = AugmenterBuilder().build(augmenter_args)
def __call__(self, data):
image = data['image']
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
data['image'] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
def may_augment_annotation(self, aug, data, shape):
if aug is None:
return data
line_polys = []
for poly in data['polys']:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data['polys'] = np.array(line_polys)
return data
def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(
keypoints, shape=img_shape)])[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from ppocr.utils.logging import get_logger
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
import json
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False):
support_character_type = ['ch', 'en', 'en_sensitive']
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str)
self.max_text_len = max_text_length
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type == "ch":
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if len(text) > self.max_text_len:
return None
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
# logger = get_logger()
# logger.warning('{} is not in dict'.format(char))
continue
text_list.append(self.dict[char])
if len(text_list) == 0:
return None
return text_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(CTCLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(AttnLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.beg_str = "sos"
self.end_str = "eos"
def add_special_char(self, dict_character):
dict_character = [self.beg_str, self.end_str] + dict_character
return dict_character
def __call__(self, text):
text = self.encode(text)
return text
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
warnings.simplefilter("ignore")
__all__ = ['MakeBorderMap']
class MakeBorderMap(object):
def __init__(self,
shrink_ratio=0.4,
thresh_min=0.3,
thresh_max=0.7,
**kwargs):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
def __call__(self, data: dict) -> dict:
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
canvas = np.zeros(img.shape[:2], dtype=np.float32)
mask = np.zeros(img.shape[:2], dtype=np.float32)
for i in range(len(text_polys)):
if ignore_tags[i]:
continue
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
data['threshold_map'] = canvas
data['threshold_mask'] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
polygon = np.array(polygon)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
distance = polygon_shape.area * (
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
xmin = padded_polygon[:, 0].min()
xmax = padded_polygon[:, 0].max()
ymin = padded_polygon[:, 1].min()
ymax = padded_polygon[:, 1].max()
width = xmax - xmin + 1
height = ymax - ymin + 1
polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
np.linspace(
0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
np.linspace(
0, height - 1, num=height).reshape(height, 1), (height, width))
distance_map = np.zeros(
(polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
xmin_valid - xmin:xmax_valid - xmax + width],
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
def _distance(self, xs, ys, point_1, point_2):
'''
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
'''
height, width = xs.shape[:2]
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
point_1[1] - point_2[1])
cosin = (square_distance - square_distance_1 - square_distance_2) / (
2 * np.sqrt(square_distance_1 * square_distance_2))
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
square_distance)
result[cosin <
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
< 0]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result, shrink_ratio):
ex_point_1 = (int(
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
int(
round(point_1[1] + (point_1[1] - point_2[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_1),
tuple(point_1),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
ex_point_2 = (int(
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
int(
round(point_2[1] + (point_2[1] - point_1[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_2),
tuple(point_2),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
return ex_point_1, ex_point_2
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
__all__ = ['MakeShrinkMap']
class MakeShrinkMap(object):
r'''
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
'''
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
def __call__(self, data):
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w = image.shape[:2]
text_polys, ignore_tags = self.validate_polygons(text_polys,
ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
# gt = np.zeros((1, h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
polygon = text_polys[i]
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
polygon_shape = Polygon(polygon)
distance = polygon_shape.area * (
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in text_polys[i]]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
shrinked = padding.Execute(-distance)
if shrinked == []:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
# cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
data['shrink_map'] = gt
data['shrink_mask'] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
'''
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
'''
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
for polygon in polygons:
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
for i in range(len(polygons)):
area = self.polygon_area(polygons[i])
if abs(area) < 1:
ignore_tags[i] = True
if area > 0:
polygons[i] = polygons[i][::-1, :]
return polygons, ignore_tags
def polygon_area(self, polygon):
# return cv2.contourArea(polygon.astype(np.float32))
edge = 0
for i in range(polygon.shape[0]):
next_index = (i + 1) % polygon.shape[0]
edge += (polygon[next_index, 0] - polygon[i, 0]) * (
polygon[next_index, 1] - polygon[i, 1])
return edge / 2.
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class DecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class keepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
if self.resize_type == 0:
img, shape = self.resize_image_type0(img)
else:
img, shape = self.resize_image_type1(img)
data['image'] = img
data['shape'] = shape
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, np.array([ori_h, ori_w])
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
else:
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
return img, np.array([h, w])
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
import random
def is_poly_in_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
h, w, _ = im.shape
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = split_regions(h_axis)
w_regions = split_regions(w_axis)
for i in range(max_tries):
if len(w_regions) > 1:
xmin, xmax = region_wise_random_select(w_regions, w)
else:
xmin, xmax = random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = region_wise_random_select(h_regions, h)
else:
ymin, ymax = random_select(h_axis, h)
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
ymax - ymin):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
class EastRandomCropData(object):
def __init__(self,
size=(640, 640),
max_tries=10,
min_crop_side_ratio=0.1,
keep_ratio=True,
**kwargs):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.keep_ratio = keep_ratio
def __call__(self, data):
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
texts = data['texts']
all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = crop_area(
img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
img.dtype)
padimg[:h, :w] = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(self.size))
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.array(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data
class PSERandomCrop(object):
def __init__(self, size, **kwargs):
self.size = size
def __call__(self, data):
imgs = data['imgs']
h, w = imgs[0].shape[0:2]
th, tw = self.size
if w == tw and h == th:
return imgs
# label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制
if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
# 文本实例的左上角点
tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
tl[tl < 0] = 0
# 文本实例的右下角点
br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
br[br < 0] = 0
# 保证选到右下角点时,有足够的距离进行crop
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
for _ in range(50000):
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
# 保证shrink_label_map有文本
if imgs[1][i:i + th, j:j + tw].sum() <= 0:
continue
else:
break
else:
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
# return i, j, th, tw
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
else:
imgs[idx] = imgs[idx][i:i + th, j:j + tw]
data['imgs'] = imgs
return data
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
class RecAug(object):
def __init__(self, **kwargsz):
pass
def __call__(self, data):
img = data['image']
img = warp(img, 10)
data['image'] = img
return data
class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_type='ch',
use_tps=False,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_type = character_type
self.use_tps = use_tps
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_type == "ch" and not self.use_tps:
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape)
data['image'] = norm_img
return data
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def flag():
"""
flag
"""
return 1 if random.random() > 0.5000001 else -1
def cvtColor(img):
"""
cvtColor
"""
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
delta = 0.001 * random.random() * flag()
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
return new_img
def blur(img):
"""
blur
"""
h, w, _ = img.shape
if h > 10 and w > 10:
return cv2.GaussianBlur(img, (5, 5), 1)
else:
return img
def jitter(img):
"""
jitter
"""
w, h, _ = img.shape
if h > 10 and w > 10:
thres = min(w, h)
s = int(random.random() * thres * 0.01)
src_img = img.copy()
for i in range(s):
img[i:, i:, :] = src_img[:w - i, :h - i, :]
return img
else:
return img
def add_gasuss_noise(image, mean=0, var=0.1):
"""
Gasuss noise
"""
noise = np.random.normal(mean, var**0.5, image.shape)
out = image + 0.5 * noise
out = np.clip(out, 0, 255)
out = np.uint8(out)
return out
def get_crop(image):
"""
random crop
"""
h, w, _ = image.shape
top_min = 1
top_max = 8
top_crop = int(random.randint(top_min, top_max))
top_crop = min(top_crop, h - 1)
crop_img = image.copy()
ratio = random.randint(0, 1)
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
crop_img = crop_img[0:h - top_crop, :, :]
return crop_img
class Config:
"""
Config
"""
def __init__(self, ):
self.anglex = random.random() * 30
self.angley = random.random() * 15
self.anglez = random.random() * 10
self.fov = 42
self.r = 0
self.shearx = random.random() * 0.3
self.sheary = random.random() * 0.05
self.borderMode = cv2.BORDER_REPLICATE
def make(self, w, h, ang):
"""
make
"""
self.anglex = random.random() * 5 * flag()
self.angley = random.random() * 5 * flag()
self.anglez = -1 * random.random() * int(ang) * flag()
self.fov = 42
self.r = 0
self.shearx = 0
self.sheary = 0
self.borderMode = cv2.BORDER_REPLICATE
self.w = w
self.h = h
self.perspective = True
self.stretch = True
self.distort = True
self.crop = True
self.affine = False
self.reverse = True
self.noise = True
self.jitter = True
self.blur = True
self.color = True
def rad(x):
"""
rad
"""
return x * np.pi / 180
def get_warpR(config):
"""
get_warpR
"""
anglex, angley, anglez, fov, w, h, r = \
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
if w > 69 and w < 112:
anglex = anglex * 1.5
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
# Homogeneous coordinate transformation matrix
rx = np.array([[1, 0, 0, 0],
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
0,
-np.sin(rad(anglex)),
np.cos(rad(anglex)),
0,
], [0, 0, 0, 1]], np.float32)
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
[0, 1, 0, 0], [
-np.sin(rad(angley)),
0,
np.cos(rad(angley)),
0,
], [0, 0, 0, 1]], np.float32)
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
r = rx.dot(ry).dot(rz)
# generate 4 points
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
dst1 = r.dot(p1)
dst2 = r.dot(p2)
dst3 = r.dot(p3)
dst4 = r.dot(p4)
list_dst = np.array([dst1, dst2, dst3, dst4])
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
dst = np.zeros((4, 2), np.float32)
# Project onto the image plane
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
warpR = cv2.getPerspectiveTransform(org, dst)
dst1, dst2, dst3, dst4 = dst
r1 = int(min(dst1[1], dst2[1]))
r2 = int(max(dst3[1], dst4[1]))
c1 = int(min(dst1[0], dst3[0]))
c2 = int(max(dst2[0], dst4[0]))
try:
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
dx = -c1
dy = -r1
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
ret = T1.dot(warpR)
except:
ratio = 1.0
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
ret = T1
return ret, (-r1, -c1), ratio, dst
def get_warpAffine(config):
"""
get_warpAffine
"""
anglez = config.anglez
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
return rz
def warp(img, ang):
"""
warp
"""
h, w, _ = img.shape
config = Config()
config.make(w, h, ang)
new_img = img
prob = 0.4
if config.distort:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_distort(new_img, random.randint(3, 6))
if config.stretch:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_stretch(new_img, random.randint(3, 6))
if config.perspective:
if random.random() <= prob:
new_img = tia_perspective(new_img)
if config.crop:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = get_crop(new_img)
if config.blur:
if random.random() <= prob:
new_img = blur(new_img)
if config.color:
if random.random() <= prob:
new_img = cvtColor(new_img)
if config.jitter:
new_img = jitter(new_img)
if config.noise:
if random.random() <= prob:
new_img = add_gasuss_noise(new_img)
if config.reverse:
if random.random() <= prob:
new_img = 255 - new_img
return new_img
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from .warp_mls import WarpMLS
def tia_distort(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_stretch(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_perspective(src):
img_h, img_w = src.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class WarpMLS:
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
self.src = src
self.src_pts = src_pts
self.dst_pts = dst_pts
self.pt_count = len(self.dst_pts)
self.dst_w = dst_w
self.dst_h = dst_h
self.trans_ratio = trans_ratio
self.grid_size = 100
self.rdx = np.zeros((self.dst_h, self.dst_w))
self.rdy = np.zeros((self.dst_h, self.dst_w))
@staticmethod
def __bilinear_interp(x, y, v11, v12, v21, v22):
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
(1 - y) + v22 * y) * x
def generate(self):
self.calc_delta()
return self.gen_img()
def calc_delta(self):
w = np.zeros(self.pt_count, dtype=np.float32)
if self.pt_count < 2:
return
i = 0
while 1:
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
i = self.dst_w - 1
elif i >= self.dst_w:
break
j = 0
while 1:
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
j = self.dst_h - 1
elif j >= self.dst_h:
break
sw = 0
swp = np.zeros(2, dtype=np.float32)
swq = np.zeros(2, dtype=np.float32)
new_pt = np.zeros(2, dtype=np.float32)
cur_pt = np.array([i, j], dtype=np.float32)
k = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
break
w[k] = 1. / (
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
sw += w[k]
swp = swp + w[k] * np.array(self.dst_pts[k])
swq = swq + w[k] * np.array(self.src_pts[k])
if k == self.pt_count - 1:
pstar = 1 / sw * swp
qstar = 1 / sw * swq
miu_s = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
miu_s += w[k] * np.sum(pt_i * pt_i)
cur_pt -= pstar
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
tmp_pt *= (w[k] / miu_s)
new_pt += tmp_pt
new_pt += qstar
else:
new_pt = self.src_pts[k]
self.rdx[j, i] = new_pt[0] - i
self.rdy[j, i] = new_pt[1] - j
j += self.grid_size
i += self.grid_size
def gen_img(self):
src_h, src_w = self.src.shape[:2]
dst = np.zeros_like(self.src, dtype=np.float32)
for i in np.arange(0, self.dst_h, self.grid_size):
for j in np.arange(0, self.dst_w, self.grid_size):
ni = i + self.grid_size
nj = j + self.grid_size
w = h = self.grid_size
if ni >= self.dst_h:
ni = self.dst_h - 1
h = ni - i + 1
if nj >= self.dst_w:
nj = self.dst_w - 1
w = nj - j + 1
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self.__bilinear_interp(
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
self.rdx[ni, j], self.rdx[ni, nj])
delta_y = self.__bilinear_interp(
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
self.rdy[ni, j], self.rdy[ni, nj])
nx = j + dj + delta_x * self.trans_ratio
ny = i + di + delta_y * self.trans_ratio
nx = np.clip(nx, 0, src_w - 1)
ny = np.clip(ny, 0, src_h - 1)
nxi = np.array(np.floor(nx), dtype=np.int32)
nyi = np.array(np.floor(ny), dtype=np.int32)
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
if len(self.src.shape) == 3:
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
else:
x = ny - nyi
y = nx - nxi
dst[i:i + h, j:j + w] = self.__bilinear_interp(
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
self.src[nyi1, nxi], self.src[nyi1, nxi1])
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
return dst
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['DetMetric']
from .eval_det_iou import DetectionIoUEvaluator
class DetMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.evaluator = DetectionIoUEvaluator()
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': '',
'ignore': ignore_tag
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
# prepare det
det_info_list = [{
'points': det_polyon,
'text': ''
} for det_polyon in pred['points']]
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
def get_metric(self):
"""
return metircs {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metircs = self.evaluator.combine_results(self.results)
self.reset()
return metircs
def reset(self):
self.results = [] # clear results
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import Levenshtein
class RecMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
all_num = 0
norm_edit_dis = 0.0
for (pred, pred_conf), (target, _) in zip(preds, labels):
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target))
if pred == target:
correct_num += 1
all_num += 1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
return {
'acc': correct_num / all_num,
'norm_edit_dis': 1 - norm_edit_dis / all_num
}
def get_metric(self):
"""
return metircs {
'acc': 0,
'norm_edit_dis': 0,
}
"""
acc = self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
__all__ = ['build_metric']
def build_metric(config):
from .DetMetric import DetMetric
from .RecMetric import RecMetric
support_dict = ['DetMetric', 'RecMetric']
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception(
'metric only support {}'.format(support_dict))
module_class = eval(module_name)(**config)
return module_class
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import namedtuple
import numpy as np
from shapely.geometry import Polygon
"""
reference from :
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
"""
class DetectionIoUEvaluator(object):
def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
self.iou_constraint = iou_constraint
self.area_precision_constraint = area_precision_constraint
def evaluate_image(self, gt, pred):
def get_union(pD, pG):
return Polygon(pD).union(Polygon(pG)).area
def get_intersection_over_union(pD, pG):
return get_intersection(pD, pG) / get_union(pD, pG)
def get_intersection(pD, pG):
return Polygon(pD).intersection(Polygon(pG)).area
def compute_ap(confList, matchList, numGtCare):
correct = 0
AP = 0
if len(confList) > 0:
confList = np.array(confList)
matchList = np.array(matchList)
sorted_ind = np.argsort(-confList)
confList = confList[sorted_ind]
matchList = matchList[sorted_ind]
for n in range(len(confList)):
match = matchList[n]
if match:
correct += 1
AP += float(correct) / (n + 1)
if numGtCare > 0:
AP /= numGtCare
return AP
perSampleMetrics = {}
matchedSum = 0
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
numGlobalCareGt = 0
numGlobalCareDet = 0
arrGlobalConfidences = []
arrGlobalMatches = []
recall = 0
precision = 0
hmean = 0
detMatched = 0
iouMat = np.empty([1, 1])
gtPols = []
detPols = []
gtPolPoints = []
detPolPoints = []
# Array of Ground Truth Polygons' keys marked as don't Care
gtDontCarePolsNum = []
# Array of Detected Polygons' matched with a don't Care GT
detDontCarePolsNum = []
pairs = []
detMatchedNums = []
arrSampleConfidences = []
arrSampleMatch = []
evaluationLog = ""
# print(len(gt))
for n in range(len(gt)):
points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore']
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
gtPol = points
gtPols.append(gtPol)
gtPolPoints.append(points)
if dontCare:
gtDontCarePolsNum.append(len(gtPols) - 1)
evaluationLog += "GT polygons: " + str(len(gtPols)) + (
" (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
if len(gtDontCarePolsNum) > 0 else "\n")
for n in range(len(pred)):
points = pred[n]['points']
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
detPol = points
detPols.append(detPol)
detPolPoints.append(points)
if len(gtDontCarePolsNum) > 0:
for dontCarePol in gtDontCarePolsNum:
dontCarePol = gtPols[dontCarePol]
intersected_area = get_intersection(dontCarePol, detPol)
pdDimensions = Polygon(detPol).area
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
if (precision > self.area_precision_constraint):
detDontCarePolsNum.append(len(detPols) - 1)
break
evaluationLog += "DET polygons: " + str(len(detPols)) + (
" (" + str(len(detDontCarePolsNum)) + " don't care)\n"
if len(detDontCarePolsNum) > 0 else "\n")
if len(gtPols) > 0 and len(detPols) > 0:
# Calculate IoU and precision matrixs
outputShape = [len(gtPols), len(detPols)]
iouMat = np.empty(outputShape)
gtRectMat = np.zeros(len(gtPols), np.int8)
detRectMat = np.zeros(len(detPols), np.int8)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
pG = gtPols[gtNum]
pD = detPols[detNum]
iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
for gtNum in range(len(gtPols)):
for detNum in range(len(detPols)):
if gtRectMat[gtNum] == 0 and detRectMat[
detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
if iouMat[gtNum, detNum] > self.iou_constraint:
gtRectMat[gtNum] = 1
detRectMat[detNum] = 1
detMatched += 1
pairs.append({'gt': gtNum, 'det': detNum})
detMatchedNums.append(detNum)
evaluationLog += "Match GT #" + \
str(gtNum) + " with Det #" + str(detNum) + "\n"
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
if numGtCare == 0:
recall = float(1)
precision = float(0) if numDetCare > 0 else float(1)
else:
recall = float(detMatched) / numGtCare
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
hmean = 0 if (precision + recall) == 0 else 2.0 * \
precision * recall / (precision + recall)
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
perSampleMetrics = {
'precision': precision,
'recall': recall,
'hmean': hmean,
'pairs': pairs,
'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
'gtPolPoints': gtPolPoints,
'detPolPoints': detPolPoints,
'gtCare': numGtCare,
'detCare': numDetCare,
'gtDontCare': gtDontCarePolsNum,
'detDontCare': detDontCarePolsNum,
'detMatched': detMatched,
'evaluationLog': evaluationLog
}
return perSampleMetrics
def combine_results(self, results):
numGlobalCareGt = 0
numGlobalCareDet = 0
matchedSum = 0
for result in results:
numGlobalCareGt += result['gtCare']
numGlobalCareDet += result['detCare']
matchedSum += result['detMatched']
methodRecall = 0 if numGlobalCareGt == 0 else float(
matchedSum) / numGlobalCareGt
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
'precision': methodPrecision,
'recall': methodRecall,
'hmean': methodHmean
}
return methodMetrics
if __name__ == '__main__':
evaluator = DetectionIoUEvaluator()
gts = [[{
'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
'text': 1234,
'ignore': False,
}, {
'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
'text': 5678,
'ignore': False,
}]]
preds = [[{
'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
'text': 123,
'ignore': False,
}]]
results = []
for gt, pred in zip(gts, preds):
results.append(evaluator.evaluate_image(gt, pred))
metrics = evaluator.combine_results(results)
print(metrics)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from .losses import build_loss
__all__ = ['build_model', 'build_loss']
def build_model(config):
from .architectures import Model
config = copy.deepcopy(config)
module_class = Model(config)
return module_class
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment