Commit 019b16be authored by chenxj

first commit

parent 3019db46
import argparse
import os
import sys
import platform
import cv2
import numpy as np
import paddle
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
import time
from ppocr.utils.logging import get_logger
def str2bool(v):
return v.lower() in ("true", "t", "1")
def init_args():
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--use_xpu", type=str2bool, default=False)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=15)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
parser.add_argument("--vis_seg_map", type=str2bool, default=False)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
"--rec_char_dict_path",
type=str,
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument(
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=False)
#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
# multi-process
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--benchmark", type=str2bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/")
parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
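# A hedged usage sketch: build the parser and parse an explicit argv list.
# The model and image paths below are hypothetical placeholders:
#
#   args = init_args().parse_args([
#       "--det_model_dir", "./inference/det",
#       "--rec_model_dir", "./inference/rec",
#       "--image_dir", "./doc/imgs",
#   ])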
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
else:
model_dir = args.rec_model_dir
    if model_dir is None:
        logger.info("{} model path is not specified: {}".format(mode, model_dir))
        sys.exit(0)
if args.use_onnx:
import onnxruntime as ort
model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("model file path not found: {}".format(
                model_file_path))
        sess = ort.InferenceSession(
            model_file_path,
            providers=[('ROCMExecutionProvider', {'device_id': '4'}),
                       'CPUExecutionProvider'])
        return sess, sess.get_inputs()[0], None, None
else:
model_file_path = model_dir + "/inference.pdmodel"
params_file_path = model_dir + "/inference.pdiparams"
        if not os.path.exists(model_file_path):
            raise ValueError("model file path not found: {}".format(
                model_file_path))
        if not os.path.exists(params_file_path):
            raise ValueError("params file path not found: {}".format(
                params_file_path))
config = inference.Config(model_file_path, params_file_path)
if hasattr(args, 'precision'):
if args.precision == "fp16" and args.use_tensorrt:
precision = inference.PrecisionType.Half
print("fp16 set success!")
elif args.precision == "int8":
precision = inference.PrecisionType.Int8
else:
precision = inference.PrecisionType.Float32
else:
precision = inference.PrecisionType.Float32
if args.use_gpu:
gpu_id = get_infer_gpuid()
if gpu_id is None:
                logger.warning(
                    "GPU is not found in current device by nvidia-smi. "
                    "Please check your device, or ignore this if running on Jetson."
                )
config.enable_use_gpu(args.gpu_mem, 0)
use_dynamic_shape = True
if mode == "det":
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_92.tmp_0": [1, 120, 20, 20],
"conv2d_91.tmp_0": [1, 24, 10, 10],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
"nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
"conv2d_124.tmp_0": [1, 256, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
"elementwise_add_7": [1, 56, 2, 2],
"nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
}
max_input_shape = {
"x": [1, 3, 1536, 1536],
"conv2d_92.tmp_0": [1, 120, 400, 400],
"conv2d_91.tmp_0": [1, 24, 200, 200],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
"conv2d_124.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
"elementwise_add_7": [1, 56, 400, 400],
"nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_92.tmp_0": [1, 120, 160, 160],
"conv2d_91.tmp_0": [1, 24, 80, 80],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
"nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
"conv2d_124.tmp_0": [1, 256, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
"elementwise_add_7": [1, 56, 40, 40],
"nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
}
min_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
"nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
}
max_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
}
opt_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
"nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
}
min_input_shape.update(min_pact_shape)
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
use_dynamic_shape = False
imgH = int(args.rec_image_shape.split(',')[-2])
min_input_shape = {"x": [1, 3, imgH, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
config.exp_disable_tensorrt_ops(["transpose2"])
elif mode == "cls":
min_input_shape = {"x": [1, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
use_dynamic_shape = False
if use_dynamic_shape:
config.set_trt_dynamic_shape_info(
min_input_shape, max_input_shape, opt_input_shape)
elif args.use_xpu:
config.enable_xpu(10 * 1024 * 1024)
else:
config.disable_gpu()
if hasattr(args, "cpu_threads"):
config.set_cpu_math_library_num_threads(args.cpu_threads)
else:
# default cpu threads as 10
config.set_cpu_math_library_num_threads(10)
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if args.precision == "fp16":
config.enable_mkldnn_bfloat16()
# enable memory optim
config.enable_memory_optim()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.delete_pass("matmul_transpose_reshape_fuse_pass")
if mode == 'table':
config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True)
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
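# A minimal usage sketch for the Paddle branch, assuming `args` comes from
# parse_args() with a valid --det_model_dir; the (1, 3, 640, 640) input is an
# arbitrary placeholder shape:
#
#   logger = get_logger()
#   args = parse_args()
#   predictor, input_tensor, output_tensors, config = create_predictor(
#       args, 'det', logger)
#   input_tensor.copy_from_cpu(np.zeros((1, 3, 640, 640), dtype=np.float32))
#   predictor.run()
#   outputs = [t.copy_to_cpu() for t in output_tensors]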
def get_output_tensors(args, mode, predictor):
    output_names = predictor.get_output_names()
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        # these algorithms expose a single softmax output when available
        output_name = 'softmax_0.tmp_0'
        if output_name in output_names:
            return [predictor.get_output_handle(output_name)]
    output_tensors = []
    for output_name in output_names:
        output_tensors.append(predictor.get_output_handle(output_name))
    return output_tensors
def get_infer_gpuid():
sysstr = platform.system()
if sysstr == "Windows":
return 0
if not paddle.fluid.core.is_compiled_with_rocm():
cmd = "env | grep CUDA_VISIBLE_DEVICES"
else:
cmd = "env | grep HIP_VISIBLE_DEVICES"
env_cuda = os.popen(cmd).readlines()
if len(env_cuda) == 0:
return 0
else:
        gpu_id = env_cuda[0].strip().split("=")[1]
        # take the first visible device; split on ',' so multi-digit ids work
        return int(gpu_id.split(",")[0])
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
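# A hedged usage sketch ("demo.jpg" is a hypothetical file):
#
#   boxes = [np.array([[10, 10], [100, 10], [100, 40], [10, 40]])]
#   vis = draw_text_det_res(boxes, "demo.jpg")
#   cv2.imwrite("./inference_results/det_res_demo.jpg", vis)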
def draw_ocr_box_txt(image,
boxes,
txts,
scores=None,
drop_score=0.5,
font_path="./doc/simfang.ttf"):
h, w = image.height, image.width
img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255))
import random
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
draw_right = ImageDraw.Draw(img_right)
for idx, (box, txt) in enumerate(zip(boxes, txts)):
if scores is not None and scores[idx] < drop_score:
continue
color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255))
draw_left.polygon(box, fill=color)
draw_right.polygon(
[
box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
box[2][1], box[3][0], box[3][1]
],
outline=color)
        box_height = math.sqrt((box[0][0] - box[3][0])**2 +
                               (box[0][1] - box[3][1])**2)
        box_width = math.sqrt((box[0][0] - box[1][0])**2 +
                              (box[0][1] - box[1][1])**2)
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
cur_y = box[0][1]
for c in txt:
char_size = font.getsize(c)
draw_right.text(
(box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
cur_y += char_size[1]
else:
font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
draw_right.text(
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
return np.array(img_show)
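# A hedged usage sketch: the image must be a PIL Image, boxes are 4x2 point
# lists, and the font path follows the --vis_font_path default ("demo.jpg" is
# hypothetical):
#
#   pil_img = Image.open("demo.jpg").convert('RGB')
#   vis = draw_ocr_box_txt(pil_img, boxes, txts, scores, drop_score=0.5,
#                          font_path="./doc/fonts/simfang.ttf")
#   cv2.imwrite("vis.jpg", vis[:, :, ::-1])  # RGB -> BGR for cv2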
def get_rotate_crop_image(img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
if __name__ == '__main__':
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import platform
import yaml
import time
import datetime
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict, AverageMeter
from ppocr.utils.logging import get_logger
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
from ppocr.utils import profiler
from ppocr.data import build_dataloader
import numpy as np
def str2bool(v):
return v.lower() in ("true", "t", "1")
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='+', help="set configuration options")
self.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format ' \
'\"key1=value1;key2=value2;key3=value3\".'
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
def _parse_opt(self, opts):
config = {}
if not opts:
return config
for s in opts:
s = s.strip()
k, v = s.split('=')
config[k] = yaml.load(v, Loader=yaml.Loader)
return config
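# For example, "-o Global.use_gpu=false Global.epoch_num=10" is parsed into
# {'Global.use_gpu': False, 'Global.epoch_num': 10}; the dotted keys are
# expanded into the nested config later by merge_config.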
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
return config
def merge_config(config, opts):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
for key, value in opts.items():
if "." not in key:
if isinstance(value, dict) and key in config:
config[key].update(value)
else:
config[key] = value
else:
sub_keys = key.split('.')
assert (
sub_keys[0] in config
), "the sub_keys can only be one of global_config: {}, but get: " \
"{}, please check your running command".format(
config.keys(), sub_keys[0])
cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
cur[sub_key] = value
else:
cur = cur[sub_key]
return config
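# A worked example with a hypothetical config:
#
#   cfg = {'Global': {'use_gpu': True, 'epoch_num': 5}}
#   merge_config(cfg, {'Global.use_gpu': False})
#   # cfg is now {'Global': {'use_gpu': False, 'epoch_num': 5}}:
#   # the dotted key is split, 'Global' is looked up, and only the
#   # leaf 'use_gpu' is overwritten.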
def check_device(use_gpu, use_xpu=False):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config {} cannot be set as true while your paddle " \
"is not compiled with {} ! \nPlease try: \n" \
"\t1. Install paddlepaddle to run model on {} \n" \
"\t2. Set {} as false in config file to run " \
"model on CPU"
try:
if use_gpu and use_xpu:
print("use_xpu and use_gpu can not both be ture.")
if use_gpu and not paddle.is_compiled_with_cuda():
print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
sys.exit(1)
if use_xpu and not paddle.device.is_compiled_with_xpu():
print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
sys.exit(1)
except Exception as e:
pass
def check_xpu(use_xpu):
"""
Log error and exit when set use_xpu=true in paddlepaddle
cpu/gpu version.
"""
err = "Config use_xpu cannot be set as true while you are " \
"using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-xpu to run model on XPU \n" \
"\t2. Set use_xpu as false in config file to run " \
"model on CPU/GPU"
try:
if use_xpu and not paddle.is_compiled_with_xpu():
print(err)
sys.exit(1)
except Exception as e:
pass
def eval(model,
valid_dataloader,
post_process_class,
eval_class):
total_frame = 0.0
total_time = 0.0
pbar = tqdm(
total=len(valid_dataloader),
desc='eval model:',
position=0,
leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader)
input_name = model.get_inputs()[0].name
for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter:
break
images = batch[0]
start = time.time()
images = np.array(images)
        input_dict = {input_name: images}  # avoid shadowing the builtin `input`
        preds = model.run(None, input_feed=input_dict)
batch_numpy = []
for item in batch:
batch_numpy.append(item.numpy())
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
preds = preds[0]
if eval_class.main_indicator == 'hmean':
onnx_preds = {'maps':preds}
else:
onnx_preds = preds
post_result = post_process_class(onnx_preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
pbar.update(1)
total_frame += len(images)
    # Get the final metric, e.g. acc or hmean
metric = eval_class.get_metric()
pbar.close()
metric['fps'] = total_frame / total_time
return metric
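# A hedged usage sketch: `model` here is expected to be an
# onnxruntime.InferenceSession (the function calls model.get_inputs() and
# model.run()); "model.onnx" is a hypothetical path, and the dataloader,
# post-process, and metric objects are assumed to come from ppocr builders:
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession("model.onnx",
#                               providers=["CPUExecutionProvider"])
#   metric = eval(sess, valid_dataloader, post_process_class, eval_class)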
def update_center(char_center, post_result, preds):
result, label = post_result
feats, logits = preds
logits = paddle.argmax(logits, axis=-1)
feats = feats.numpy()
logits = logits.numpy()
for idx_sample in range(len(label)):
if result[idx_sample][0] == label[idx_sample][0]:
feat = feats[idx_sample]
logit = logits[idx_sample]
for idx_time in range(len(logit)):
index = logit[idx_time]
if index in char_center.keys():
char_center[index][0] = (
char_center[index][0] * char_center[index][1] +
feat[idx_time]) / (char_center[index][1] + 1)
char_center[index][1] += 1
else:
char_center[index] = [feat[idx_time], 1]
return char_center
def get_center(model, eval_dataloader, post_process_class):
pbar = tqdm(total=len(eval_dataloader), desc='get center:')
max_iter = len(eval_dataloader) - 1 if platform.system(
) == "Windows" else len(eval_dataloader)
char_center = dict()
for idx, batch in enumerate(eval_dataloader):
if idx >= max_iter:
break
images = batch[0]
start = time.time()
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
post_result = post_process_class(preds, batch[1])
        # update char_center
char_center = update_center(char_center, post_result, preds)
pbar.update(1)
pbar.close()
for key in char_center.keys():
char_center[key] = char_center[key][0]
return char_center
def preprocess(is_train=False):
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)
if is_train:
# save_config
save_model_dir = config['Global']['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir)
else:
log_file = None
logger = get_logger(log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = config['Global'].get('use_xpu', False)
    check_xpu(use_xpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
]
if use_xpu:
device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
else:
device = 'gpu:{}'.format(dist.ParallelEnv()
.dev_id) if use_gpu else 'cpu'
check_device(use_gpu, use_xpu)
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
loggers = []
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config:
wandb_params = config['wandb']
else:
wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
log_writer = WandbLogger(**wandb_params, config=config)
loggers.append(log_writer)
else:
log_writer = None
print_dict(config, logger)
if loggers:
log_writer = Loggers(loggers)
else:
log_writer = None
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, log_writer, FLAGS
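# A typical call from a training entry point (both paths below are
# hypothetical placeholders):
#
#   config, device, logger, log_writer, FLAGS = preprocess(is_train=True)
#
# usually launched as:
#   python train.py -c configs/det/det_db.yml -o Global.use_gpu=false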
import numpy as np
import cv2
from PIL import Image
import os
def rotate_cut_img(im, degree, x_center, y_center, w, h, leftAdjust=False, rightAdjust=False, alph=0.2):
# degree_ = degree * 180.0 / np.pi
# print(degree_)
right = 0
left = 0
if rightAdjust:
right = 1
if leftAdjust:
left = 1
    box = (max(1, x_center - w / 2 - left * alph * (w / 2)),  # xmin
           y_center - h / 2,  # ymin
           min(x_center + w / 2 + right * alph * (w / 2), im.size[0] - 1),  # xmax
           y_center + h / 2)  # ymax
newW = box[2] - box[0]
newH = box[3] - box[1]
tmpImg = im.rotate(degree, center=(x_center, y_center)).crop(box)
return tmpImg, newW, newH
def crop_rect(img, rect, alph=0.15):
img = np.asarray(img)
# get the parameter of the small rectangle
# print("rect!")
# print(rect)
center, size, angle = rect[0], rect[1], rect[2]
min_size = min(size)
if angle > -45:
center, size = tuple(map(int, center)), tuple(map(int, size))
# angle-=270
size = (int(size[0] + min_size * alph), int(size[1] + min_size * alph))
height, width = img.shape[0], img.shape[1]
M = cv2.getRotationMatrix2D(center, angle, 1)
# size = tuple([int(rect[1][1]), int(rect[1][0])])
img_rot = cv2.warpAffine(img, M, (width, height))
# cv2.imwrite("debug_im/img_rot.jpg", img_rot)
img_crop = cv2.getRectSubPix(img_rot, size, center)
else:
center = tuple(map(int, center))
size = tuple([int(rect[1][1]), int(rect[1][0])])
size = (int(size[0] + min_size * alph), int(size[1] + min_size * alph))
angle -= 270
height, width = img.shape[0], img.shape[1]
M = cv2.getRotationMatrix2D(center, angle, 1)
img_rot = cv2.warpAffine(img, M, (width, height))
# cv2.imwrite("debug_im/img_rot.jpg", img_rot)
img_crop = cv2.getRectSubPix(img_rot, size, center)
img_crop = Image.fromarray(img_crop)
return img_crop
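# Usage note (a hedged sketch): `rect` follows cv2.minAreaRect's
# (center, size, angle) convention, e.g.:
#
#   rect = cv2.minAreaRect(contour_points.astype(np.float32))
#   crop = crop_rect(img, rect)  # returns a PIL Image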
def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
if isinstance(img_path, str):
img_path = cv2.imread(img_path)
# img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
img_path = img_path.copy()
for point in result:
point = point.astype(int)
cv2.line(img_path, tuple(point[0]), tuple(point[1]), color, thickness)
cv2.line(img_path, tuple(point[1]), tuple(point[2]), color, thickness)
cv2.line(img_path, tuple(point[2]), tuple(point[3]), color, thickness)
cv2.line(img_path, tuple(point[3]), tuple(point[0]), color, thickness)
return img_path
def sort_box(boxs):
res = []
for box in boxs:
# box = [x if x>0 else 0 for x in box ]
x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
newBox = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
# sort x
newBox = sorted(newBox, key=lambda x: x[0])
x1, y1 = sorted(newBox[:2], key=lambda x: x[1])[0]
index = newBox.index([x1, y1])
newBox.pop(index)
newBox = sorted(newBox, key=lambda x: -x[1])
x4, y4 = sorted(newBox[:2], key=lambda x: x[0])[0]
index = newBox.index([x4, y4])
newBox.pop(index)
newBox = sorted(newBox, key=lambda x: -x[0])
x2, y2 = sorted(newBox[:2], key=lambda x: x[1])[0]
index = newBox.index([x2, y2])
newBox.pop(index)
newBox = sorted(newBox, key=lambda x: -x[1])
x3, y3 = sorted(newBox[:2], key=lambda x: x[0])[0]
res.append([x1, y1, x2, y2, x3, y3, x4, y4])
return res
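# Example: corners given in arbitrary order come back ordered top-left,
# top-right, bottom-right, bottom-left:
#
#   sort_box([[10, 10, 0, 0, 0, 10, 10, 0]])
#   # -> [[0, 0, 10, 0, 10, 10, 0, 10]]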
def solve(box):
"""
绕 cx,cy点 w,h 旋转 angle 的坐标
x = cx-w/2
y = cy-h/2
x1-cx = -w/2*cos(angle) +h/2*sin(angle)
y1 -cy= -w/2*sin(angle) -h/2*cos(angle)
h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle)
w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle)
(hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy)
"""
x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
cx = (x1 + x3 + x2 + x4) / 4.0
cy = (y1 + y3 + y4 + y2) / 4.0
w = (np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)) / 2
h = (np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)) / 2
sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
angle = np.arcsin(sinA)
return angle, w, h, cx, cy
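# Worked example: for the axis-aligned box [0, 0, 100, 0, 100, 30, 0, 30],
# solve returns angle = 0.0, w = 100.0, h = 30.0, cx = 50.0, cy = 15.0.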
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
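# Example: two boxes on roughly the same line (y gap < 10) are re-ordered
# left to right even when the left box starts slightly lower:
#
#   dt_boxes = np.array([[[50, 5], [90, 5], [90, 20], [50, 20]],
#                        [[0, 8], [40, 8], [40, 22], [0, 22]]],
#                       dtype=np.float32)
#   sorted_boxes(dt_boxes)  # the box starting at x=0 now comes first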
def get_rotate_crop_image(img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
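# A hedged usage sketch ("demo.jpg" is hypothetical): `points` must be a
# float32 (4, 2) array ordered top-left, top-right, bottom-right, bottom-left,
# since cv2.getPerspectiveTransform requires float32 input:
#
#   img = cv2.imread("demo.jpg")
#   quad = np.float32([[10, 10], [200, 12], [198, 60], [8, 58]])
#   crop = get_rotate_crop_image(img, quad)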
def app_url(version, name):
url = '/{}/{}'.format(version, name)
return url
def _check_image_file(path):
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file):
imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("no image file found in {}".format(img_file))
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("no image file found in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
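# Example (hypothetical directory): returns a sorted list of image paths and
# raises when nothing matches:
#
#   for img_path in get_image_file_list("./doc/imgs"):
#       print(img_path)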