"web/extensions/vscode:/vscode.git/clone" did not exist on "d6830b958c1f37a584ddb5d91beeef54b6540ce4"
Commit 7b53596c authored by WenmuZhou

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph_rc

parents 0458f0cc 0e32093f
doc/joinus.PNG: image replaced (15.7 KB -> 408 KB)
@@ -35,44 +35,45 @@ __all__ = ['PaddleOCR']
 model_urls = {
     'det':
-    'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
+    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
     'rec': {
         'ch': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
         },
         'en': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
-            'dict_path': './ppocr/utils/ic15_dict.txt'
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
+            'dict_path': './ppocr/utils/dict/en_dict.txt'
         },
         'french': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/french_dict.txt'
         },
         'german': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/german_dict.txt'
         },
         'korean': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/korean_dict.txt'
         },
         'japan': {
             'url':
-            'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
+            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/japan_dict.txt'
         }
     },
     'cls':
-    'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
+    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
 }
 SUPPORT_DET_MODEL = ['DB']
+VERSION = 2.0
 SUPPORT_REC_MODEL = ['CRNN']
 BASE_DIR = os.path.expanduser("~/.paddleocr/")
@@ -94,20 +95,24 @@ def download_with_progressbar(url, save_path):
 def maybe_download(model_storage_directory, url):
     # using custom model
-    if not os.path.exists(os.path.join(
-            model_storage_directory, 'model')) or not os.path.exists(
-                os.path.join(model_storage_directory, 'params')):
+    tar_file_name_list = [
+        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
+    ]
+    if not os.path.exists(
+            os.path.join(model_storage_directory, 'inference.pdiparams')
+    ) or not os.path.exists(
+            os.path.join(model_storage_directory, 'inference.pdmodel')):
         tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
         print('download {} to {}'.format(url, tmp_path))
         os.makedirs(model_storage_directory, exist_ok=True)
         download_with_progressbar(url, tmp_path)
         with tarfile.open(tmp_path, 'r') as tarObj:
             for member in tarObj.getmembers():
-                if "model" in member.name:
-                    filename = 'model'
-                elif "params" in member.name:
-                    filename = 'params'
-                else:
-                    continue
+                filename = None
+                for tar_file_name in tar_file_name_list:
+                    if tar_file_name in member.name:
+                        filename = tar_file_name
+                if filename is None:
+                    continue
                 file = tarObj.extractfile(member)
                 with open(
@@ -176,43 +181,43 @@ def parse_args(mMain=True, add_help=True):
         parser.add_argument("--use_angle_cls", type=str2bool, default=False)
         return parser.parse_args()
     else:
-        return argparse.Namespace(use_gpu=True,
-                                  ir_optim=True,
-                                  use_tensorrt=False,
-                                  gpu_mem=8000,
-                                  image_dir='',
-                                  det_algorithm='DB',
-                                  det_model_dir=None,
-                                  det_limit_side_len=960,
-                                  det_limit_type='max',
-                                  det_db_thresh=0.3,
-                                  det_db_box_thresh=0.5,
-                                  det_db_unclip_ratio=2.0,
-                                  det_east_score_thresh=0.8,
-                                  det_east_cover_thresh=0.1,
-                                  det_east_nms_thresh=0.2,
-                                  rec_algorithm='CRNN',
-                                  rec_model_dir=None,
-                                  rec_image_shape="3, 32, 320",
-                                  rec_char_type='ch',
-                                  rec_batch_num=30,
-                                  max_text_length=25,
-                                  rec_char_dict_path=None,
-                                  use_space_char=True,
-                                  drop_score=0.5,
-                                  cls_model_dir=None,
-                                  cls_image_shape="3, 48, 192",
-                                  label_list=['0', '180'],
-                                  cls_batch_num=30,
-                                  cls_thresh=0.9,
-                                  enable_mkldnn=False,
-                                  use_zero_copy_run=False,
-                                  use_pdserving=False,
-                                  lang='ch',
-                                  det=True,
-                                  rec=True,
-                                  use_angle_cls=False
-                                  )
+        return argparse.Namespace(
+            use_gpu=True,
+            ir_optim=True,
+            use_tensorrt=False,
+            gpu_mem=8000,
+            image_dir='',
+            det_algorithm='DB',
+            det_model_dir=None,
+            det_limit_side_len=960,
+            det_limit_type='max',
+            det_db_thresh=0.3,
+            det_db_box_thresh=0.5,
+            det_db_unclip_ratio=2.0,
+            det_east_score_thresh=0.8,
+            det_east_cover_thresh=0.1,
+            det_east_nms_thresh=0.2,
+            rec_algorithm='CRNN',
+            rec_model_dir=None,
+            rec_image_shape="3, 32, 320",
+            rec_char_type='ch',
+            rec_batch_num=30,
+            max_text_length=25,
+            rec_char_dict_path=None,
+            use_space_char=True,
+            drop_score=0.5,
+            cls_model_dir=None,
+            cls_image_shape="3, 48, 192",
+            label_list=['0', '180'],
+            cls_batch_num=30,
+            cls_thresh=0.9,
+            enable_mkldnn=False,
+            use_zero_copy_run=False,
+            use_pdserving=False,
+            lang='ch',
+            det=True,
+            rec=True,
+            use_angle_cls=False)
 class PaddleOCR(predict_system.TextSystem):
@@ -228,19 +233,21 @@ class PaddleOCR(predict_system.TextSystem):
         lang = postprocess_params.lang
         assert lang in model_urls[
             'rec'], 'param lang must in {}, but got {}'.format(
                 model_urls['rec'].keys(), lang)
         if postprocess_params.rec_char_dict_path is None:
             postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                 'dict_path']
         # init model dir
         if postprocess_params.det_model_dir is None:
-            postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
+            postprocess_params.det_model_dir = os.path.join(
+                BASE_DIR, '{}/det'.format(VERSION))
         if postprocess_params.rec_model_dir is None:
             postprocess_params.rec_model_dir = os.path.join(
-                BASE_DIR, 'rec/{}'.format(lang))
+                BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
         if postprocess_params.cls_model_dir is None:
-            postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
+            postprocess_params.cls_model_dir = os.path.join(
+                BASE_DIR, '{}/cls'.format(VERSION))
         print(postprocess_params)
         # download model
         maybe_download(postprocess_params.det_model_dir, model_urls['det'])
......
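For orientation (not part of the diff), the new VERSION-based defaults resolve to paths like the following; BASE_DIR, VERSION and lang are exactly the values defined in the code above:

    import os

    BASE_DIR = os.path.expanduser("~/.paddleocr/")
    VERSION = 2.0
    lang = 'ch'

    det_dir = os.path.join(BASE_DIR, '{}/det'.format(VERSION))           # ~/.paddleocr/2.0/det
    rec_dir = os.path.join(BASE_DIR, '{}/rec/{}'.format(VERSION, lang))  # ~/.paddleocr/2.0/rec/ch
    cls_dir = os.path.join(BASE_DIR, '{}/cls'.format(VERSION))           # ~/.paddleocr/2.0/cls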
@@ -67,6 +67,7 @@ def build_dataloader(config, mode, device, logger):
     drop_last = loader_config['drop_last']
     num_workers = loader_config['num_workers']
+    use_shared_memory = False
     if mode == "Train":
         #Distribute data to multiple cards
         batch_sampler = DistributedBatchSampler(
@@ -74,6 +75,7 @@ def build_dataloader(config, mode, device, logger):
             batch_size=batch_size,
             shuffle=False,
             drop_last=drop_last)
+        use_shared_memory = True
     else:
         #Distribute data to single card
         batch_sampler = BatchSampler(
@@ -87,6 +89,7 @@ def build_dataloader(config, mode, device, logger):
         batch_sampler=batch_sampler,
         places=device,
         num_workers=num_workers,
-        return_list=True)
+        return_list=True,
+        use_shared_memory=use_shared_memory)
     return data_loader
@@ -35,12 +35,13 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort
 class RecAug(object):
-    def __init__(self, use_tia=True, **kwargs):
+    def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
         self.use_tia = use_tia
+        self.aug_prob = aug_prob

     def __call__(self, data):
         img = data['image']
-        img = warp(img, 10, self.use_tia)
+        img = warp(img, 10, self.use_tia, self.aug_prob)
         data['image'] = img
         return data
@@ -329,7 +330,7 @@ def get_warpAffine(config):
     return rz


-def warp(img, ang, use_tia=True):
+def warp(img, ang, use_tia=True, prob=0.4):
     """
     warp
     """
@@ -338,8 +339,6 @@ def warp(img, ang, use_tia=True):
     config.make(w, h, ang)
     new_img = img

-    prob = 0.4
-
     if config.distort:
         img_height, img_width = img.shape[0:2]
         if random.random() <= prob and img_height >= 20 and img_width >= 20:
......
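As a quick illustration of the new knob (a sketch, not taken from the diff; the module path and values are assumptions): RecAug now accepts aug_prob and forwards it to warp() as the augmentation probability, so it can be tuned from the dataset config instead of being hard-coded at 0.4.

    import numpy as np
    # Assumed import path: from ppocr.data.imaug.rec_img_aug import RecAug

    img = (np.random.rand(32, 100, 3) * 255).astype(np.uint8)  # illustrative input image
    aug = RecAug(use_tia=True, aug_prob=0.4)
    out = aug({'image': img})   # internally calls warp(img, 10, use_tia, aug_prob)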
@@ -32,9 +32,8 @@ class ClsMetric(object):
     def get_metric(self):
         """
-        return metircs {
-            'acc': 0,
-            'norm_edit_dis': 0,
+        return metrics {
+            'acc': 0
         }
         """
         acc = self.correct_num / self.all_num
......
@@ -57,7 +57,7 @@ class DetMetric(object):
     def get_metric(self):
         """
-        return metircs {
+        return metrics {
             'precision': 0,
             'recall': 0,
             'hmean': 0
......
@@ -43,7 +43,7 @@ class RecMetric(object):
     def get_metric(self):
         """
-        return metircs {
+        return metrics {
             'acc': 0,
             'norm_edit_dis': 0,
         }
......
@@ -16,8 +16,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

 import copy
+import paddle

 __all__ = ['build_optimizer']
@@ -49,7 +49,13 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     # step3 build optimizer
     optim_name = config.pop('name')
+    if 'clip_norm' in config:
+        clip_norm = config.pop('clip_norm')
+        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+    else:
+        grad_clip = None
     optim = getattr(optimizer, optim_name)(learning_rate=lr,
                                            weight_decay=reg,
+                                           grad_clip=grad_clip,
                                            **config)
     return optim(parameters), lr
@@ -30,18 +30,25 @@ class Momentum(object):
         regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
     """

-    def __init__(self, learning_rate, momentum, weight_decay=None, **args):
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 weight_decay=None,
+                 grad_clip=None,
+                 **args):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
+        self.grad_clip = grad_clip

     def __call__(self, parameters):
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
-            parameters=parameters,
-            weight_decay=self.weight_decay)
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            parameters=parameters)
         return opt
@@ -96,10 +103,11 @@ class RMSProp(object):
     def __init__(self,
                  learning_rate,
-                 momentum,
+                 momentum=0.0,
                  rho=0.95,
                  epsilon=1e-6,
                  weight_decay=None,
+                 grad_clip=None,
                  **args):
         super(RMSProp, self).__init__()
         self.learning_rate = learning_rate
@@ -107,6 +115,7 @@ class RMSProp(object):
         self.rho = rho
         self.epsilon = epsilon
         self.weight_decay = weight_decay
+        self.grad_clip = grad_clip

     def __call__(self, parameters):
         opt = optim.RMSProp(
@@ -115,5 +124,6 @@ class RMSProp(object):
             rho=self.rho,
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
             parameters=parameters)
         return opt
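A minimal end-to-end sketch of how the new clipping option flows (not part of the diff; values are placeholders, and Momentum is the wrapper class patched above, assumed importable from ppocr.optimizer.optimizer): build_optimizer turns a clip_norm entry into paddle.nn.ClipGradByNorm and the wrapper attaches it to the Paddle optimizer.

    import paddle

    model = paddle.nn.Linear(4, 2)                        # stand-in model
    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10.0)  # what build_optimizer creates from clip_norm
    opt = Momentum(
        learning_rate=0.001, momentum=0.9, weight_decay=None,
        grad_clip=grad_clip)(model.parameters())
    # opt is a paddle.optimizer.Momentum that clips gradients by norm before each update.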
@@ -40,7 +40,7 @@ class DBPostProcess(object):
         self.max_candidates = max_candidates
         self.unclip_ratio = unclip_ratio
         self.min_size = 3
-        self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]]
+        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])

     def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
         '''
......
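For context (an illustration, not from the diff): when use_dilation is on, this kernel dilates the binarized segmentation map, typically via OpenCV's dilate, which wants a numpy array as its structuring element; that is presumably why the kernel is now built with np.array. A tiny sketch, with the kernel dtype pinned to uint8 here purely for portability:

    import cv2
    import numpy as np

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[3:5, 3:5] = 1
    kernel = np.array([[1, 1], [1, 1]], dtype=np.uint8)  # same 2x2 kernel as above
    dilated = cv2.dilate(mask, kernel)                   # dilated has a slightly larger foreground than mask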
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import argparse
import json
def gen_rec_label(input_path, out_label):
with open(out_label, 'w') as out_file:
with open(input_path, 'r') as f:
for line in f.readlines():
tmp = line.strip('\n').replace(" ", "").split(',')
img_path, label = tmp[0], tmp[1]
label = label.replace("\"", "")
out_file.write(img_path + '\t' + label + '\n')
def gen_det_label(root_path, input_dir, out_label):
with open(out_label, 'w') as out_file:
for label_file in os.listdir(input_dir):
img_path = root_path + label_file[3:-4] + ".jpg"
label = []
with open(os.path.join(input_dir, label_file), 'r') as f:
for line in f.readlines():
tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
"").split(',')
points = tmp[:8]
s = []
for i in range(0, len(points), 2):
b = points[i:i + 2]
b = [int(t) for t in b]
s.append(b)
result = {"transcription": tmp[8], "points": s}
label.append(result)
out_file.write(img_path + '\t' + json.dumps(
label, ensure_ascii=False) + '\n')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--mode',
type=str,
default="rec",
help='Generate rec_label or det_label, can be set rec or det')
parser.add_argument(
'--root_path',
type=str,
default=".",
help='The root directory of images.Only takes effect when mode=det ')
parser.add_argument(
'--input_path',
type=str,
default=".",
help='Input_label or input path to be converted')
parser.add_argument(
'--output_label',
type=str,
default="out_label.txt",
help='Output file name')
args = parser.parse_args()
if args.mode == "rec":
print("Generate rec label")
gen_rec_label(args.input_path, args.output_label)
elif args.mode == "det":
gen_det_label(args.root_path, args.input_path, args.output_label)
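For a concrete picture of the detection label format this script emits (an illustration with made-up values, not taken from the diff): each output line is the image path, a tab, then a JSON list of boxes.

    import json

    img_path = "./icdar2015/img_1.jpg"   # made-up path
    label = [{"transcription": "PaddleOCR",
              "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}]
    line = img_path + '\t' + json.dumps(label, ensure_ascii=False) + '\n'
    # -> ./icdar2015/img_1.jpg	[{"transcription": "PaddleOCR", "points": [[310, 104], ...]}]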
@@ -33,4 +33,4 @@ v
 w
 x
 y
-z
\ No newline at end of file
+z
@@ -28,20 +28,13 @@ from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
 from ppocr.utils.save_load import init_model
 from ppocr.utils.logging import get_logger
-from tools.program import load_config
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-c", "--config", help="configuration file to use")
-    parser.add_argument(
-        "-o", "--output_path", type=str, default='./output/infer/')
-    return parser.parse_args()
+from tools.program import load_config, merge_config, ArgsParser


 def main():
-    FLAGS = parse_args()
+    FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
     logger = get_logger()

     # build post process
@@ -57,8 +50,7 @@ def main():
     init_model(config, model, logger)
     model.eval()

-    save_path = '{}/{}/inference'.format(FLAGS.output_path,
-                                         config['Architecture']['model_type'])
+    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
     infer_shape = [3, 32, 100] if config['Architecture'][
         'model_type'] != "det" else [3, 640, 640]
     model = to_static(
......
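In practice (stated as an assumption, since the invocation isn't shown here): with ArgsParser the export script now takes the usual -c config plus -o overrides, so the output location comes from Global.save_inference_dir rather than a dedicated --output_path flag, roughly:

    python tools/export_model.py -c <your_config.yml> -o Global.save_inference_dir=./output/db_inference/

where the script path, config and output directory above are illustrative.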
@@ -63,6 +63,7 @@ class TextDetector(object):
             postprocess_params["box_thresh"] = args.det_db_box_thresh
             postprocess_params["max_candidates"] = 1000
             postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+            postprocess_params["use_dilation"] = True
         else:
             logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
             sys.exit(0)
@@ -111,7 +112,7 @@ class TextDetector(object):
             box = self.clip_det_res(box, img_height, img_width)
             rect_width = int(np.linalg.norm(box[0] - box[1]))
             rect_height = int(np.linalg.norm(box[0] - box[3]))
-            if rect_width <= 10 or rect_height <= 10:
+            if rect_width <= 3 or rect_height <= 3:
                 continue
             dt_boxes_new.append(box)
         dt_boxes = np.array(dt_boxes_new)
@@ -186,4 +187,4 @@ if __name__ == "__main__":
         cv2.imwrite(img_path, src_im)
         logger.info("The visualized image saved in {}".format(img_path))
     if count > 1:
-        logger.info("Avg Time:", total_time / (count - 1))
+        logger.info("Avg Time: {}".format(total_time / (count - 1)))
@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
     if model_dir is None:
         logger.info("not find {} model file path {}".format(mode, model_dir))
         sys.exit(0)
-    model_file_path = model_dir + ".pdmodel"
-    params_file_path = model_dir + ".pdiparams"
+    model_file_path = model_dir + "/inference.pdmodel"
+    params_file_path = model_dir + "/inference.pdiparams"
     if not os.path.exists(model_file_path):
         logger.info("not find model file path {}".format(model_file_path))
         sys.exit(0)
......
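As a summary of what this expects on disk (restated from the changes above, with an illustrative path): each det/rec/cls model directory handed to create_predictor should now contain the exported inference files directly, for example:

    ~/.paddleocr/2.0/det/
        inference.pdmodel
        inference.pdiparams
        inference.pdiparams.info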
@@ -113,7 +113,6 @@ def merge_config(config):
             global_config.keys(), sub_keys[0])
         cur = global_config[sub_keys[0]]
         for idx, sub_key in enumerate(sub_keys[1:]):
-            assert (sub_key in cur)
             if idx == len(sub_keys) - 2:
                 cur[sub_key] = value
             else:
......