"include/composable_kernel/utility/Sequence.hpp" did not exist on "eafdabba771fe3a41843b86c85f6e574b2d176e1"
Commit 85a98fe2 authored by tink2123's avatar tink2123
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

parents d517e1b7 f8889760
...@@ -41,6 +41,7 @@ class TextRecognizer(object): ...@@ -41,6 +41,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.max_text_length = args.max_text_length
postprocess_params = { postprocess_params = {
'name': 'CTCLabelDecode', 'name': 'CTCLabelDecode',
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
...@@ -186,8 +187,9 @@ class TextRecognizer(object): ...@@ -186,8 +187,9 @@ class TextRecognizer(object):
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
else: else:
norm_img = self.process_image_srn( norm_img = self.process_image_srn(img_list[indices[ino]],
img_list[indices[ino]], self.rec_image_shape, 8, 25) self.rec_image_shape, 8,
self.max_text_length)
encoder_word_pos_list = [] encoder_word_pos_list = []
gsrm_word_pos_list = [] gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = [] gsrm_slf_attn_bias1_list = []
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -141,6 +142,7 @@ def sorted_boxes(dt_boxes): ...@@ -141,6 +142,7 @@ def sorted_boxes(dt_boxes):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
font_path = args.vis_font_path font_path = args.vis_font_path
...@@ -184,4 +186,18 @@ def main(args): ...@@ -184,4 +186,18 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) args = utility.parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
...@@ -48,6 +48,7 @@ def parse_args(): ...@@ -48,6 +48,7 @@ def parse_args():
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
...@@ -74,6 +75,20 @@ def parse_args(): ...@@ -74,6 +75,20 @@ def parse_args():
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5) parser.add_argument("--drop_score", type=float, default=0.5)
# params for e2e
parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
parser.add_argument("--e2e_model_dir", type=str)
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
parser.add_argument("--e2e_limit_type", type=str, default='max')
# PGNet parmas
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument(
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str) parser.add_argument("--cls_model_dir", type=str)
...@@ -85,6 +100,10 @@ def parse_args(): ...@@ -85,6 +100,10 @@ def parse_args():
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
return parser.parse_args() return parser.parse_args()
...@@ -93,8 +112,10 @@ def create_predictor(args, mode, logger): ...@@ -93,8 +112,10 @@ def create_predictor(args, mode, logger):
model_dir = args.det_model_dir model_dir = args.det_model_dir
elif mode == 'cls': elif mode == 'cls':
model_dir = args.cls_model_dir model_dir = args.cls_model_dir
else: elif mode == 'rec':
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
else:
model_dir = args.e2e_model_dir
if model_dir is None: if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir)) logger.info("not find {} model file path {}".format(mode, model_dir))
...@@ -148,6 +169,22 @@ def create_predictor(args, mode, logger): ...@@ -148,6 +169,22 @@ def create_predictor(args, mode, logger):
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
def draw_e2e_res(dt_boxes, strs, img_path):
src_im = cv2.imread(img_path)
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(
src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
return src_im
def draw_text_det_res(dt_boxes, img_path): def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path) src_im = cv2.imread(img_path)
for box in dt_boxes: for box in dt_boxes:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
if len(dt_boxes) > 0:
src_im = img
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(
src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The e2e Image saved in {}".format(save_path))
def main():
global_config = config['Global']
# build model
model = build_model(config['Architecture'])
init_model(config, model, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
with open(save_res_path, "wb") as fout:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['texts']
# write resule
dt_boxes_json = []
for poly, str in zip(points, strs):
tmp_json = {"transcription": str}
tmp_json['points'] = poly.tolist()
dt_boxes_json.append(tmp_json)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
draw_e2e_res(points, strs, config, src_img, file)
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
...@@ -73,35 +73,45 @@ def main(): ...@@ -73,35 +73,45 @@ def main():
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path',
"./output/rec/predicts_rec.txt")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval() model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file)) with open(save_res_path, "w") as fout:
with open(file, 'rb') as f: for file in get_image_file_list(config['Global']['infer_img']):
img = f.read() logger.info("infer_img: {}".format(file))
data = {'image': img} with open(file, 'rb') as f:
batch = transform(data, ops) img = f.read()
if config['Architecture']['algorithm'] == "SRN": data = {'image': img}
encoder_word_pos_list = np.expand_dims(batch[1], axis=0) batch = transform(data, ops)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) if config['Architecture']['algorithm'] == "SRN":
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
others = [ gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list), others = [
paddle.to_tensor(gsrm_slf_attn_bias1_list), paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list) paddle.to_tensor(gsrm_word_pos_list),
] paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
images = np.expand_dims(batch[0], axis=0) ]
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN": images = np.expand_dims(batch[0], axis=0)
preds = model(images, others) images = paddle.to_tensor(images)
else: if config['Architecture']['algorithm'] == "SRN":
preds = model(images) preds = model(images, others)
post_result = post_process_class(preds) else:
for rec_reuslt in post_result: preds = model(images)
logger.info('\t result: {}'.format(rec_reuslt)) post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
if len(rec_reuslt) >= 2:
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
rec_reuslt[1]) + "\n")
logger.info("success!") logger.info("success!")
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import os import os
import sys import sys
import platform
import yaml import yaml
import time import time
import shutil import shutil
...@@ -159,6 +160,8 @@ def train(config, ...@@ -159,6 +160,8 @@ def train(config,
eval_batch_step = config['Global']['eval_batch_step'] eval_batch_step = config['Global']['eval_batch_step']
global_step = 0 global_step = 0
if 'global_step' in pre_best_model_dict:
global_step = pre_best_model_dict['global_step']
start_eval_step = 0 start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2: if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0] start_eval_step = eval_batch_step[0]
...@@ -196,9 +199,11 @@ def train(config, ...@@ -196,9 +199,11 @@ def train(config,
train_reader_cost = 0.0 train_reader_cost = 0.0
batch_sum = 0 batch_sum = 0
batch_start = time.time() batch_start = time.time()
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - batch_start train_reader_cost += time.time() - batch_start
if idx >= len(train_dataloader): if idx >= max_iter:
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
images = batch[0] images = batch[0]
...@@ -287,7 +292,8 @@ def train(config, ...@@ -287,7 +292,8 @@ def train(config,
is_best=True, is_best=True,
prefix='best_accuracy', prefix='best_accuracy',
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
epoch=epoch) epoch=epoch,
global_step=global_step)
best_str = 'best metric, {}'.format(', '.join([ best_str = 'best metric, {}'.format(', '.join([
'{}: {}'.format(k, v) for k, v in best_model_dict.items() '{}: {}'.format(k, v) for k, v in best_model_dict.items()
])) ]))
...@@ -309,7 +315,8 @@ def train(config, ...@@ -309,7 +315,8 @@ def train(config,
is_best=False, is_best=False,
prefix='latest', prefix='latest',
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
epoch=epoch) epoch=epoch,
global_step=global_step)
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0: if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
save_model( save_model(
model, model,
...@@ -319,7 +326,8 @@ def train(config, ...@@ -319,7 +326,8 @@ def train(config,
is_best=False, is_best=False,
prefix='iter_epoch_{}'.format(epoch), prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
epoch=epoch) epoch=epoch,
global_step=global_step)
best_str = 'best metric, {}'.format(', '.join( best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
logger.info(best_str) logger.info(best_str)
...@@ -335,8 +343,10 @@ def eval(model, valid_dataloader, post_process_class, eval_class, ...@@ -335,8 +343,10 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
total_frame = 0.0 total_frame = 0.0
total_time = 0.0 total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model:') pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader)
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader): if idx >= max_iter:
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
...@@ -375,7 +385,8 @@ def preprocess(is_train=False): ...@@ -375,7 +385,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment