Commit 406463ef authored by Khanh Tran

update from original repo

parents 4d22bf3a bc85ebd4
@@ -21,6 +21,7 @@ import time
 import multiprocessing
 import numpy as np
 
+
 def set_paddle_flags(**kwargs):
     for key, value in kwargs.items():
         if os.environ.get(key, None) is None:
@@ -54,6 +55,7 @@ def main():
     program.merge_config(FLAGS.opt)
     logger.info(config)
     char_ops = CharacterOps(config['Global'])
+    loss_type = config['Global']['loss_type']
     config['Global']['char_ops'] = char_ops
     # check if set use_gpu=True in paddlepaddle cpu version
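Note: the hunk above assumes the recognizer config exposes Global.loss_type (either "ctc" or "attention"), which the next hunk branches on. A minimal, hypothetical illustration of the keys this script reads (the dict below is made up for this example; in the repo the values come from the YAML file passed via -c/--config):

    # hypothetical minimal config mirroring the keys read by infer_rec
    config = {
        'Global': {
            'loss_type': 'ctc',                        # or 'attention'
            'infer_img': './doc/imgs_words/word_1.jpg',
        }
    }
    loss_type = config['Global']['loss_type']
    assert loss_type in ('ctc', 'attention')
    print(loss_type)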
@@ -78,35 +80,44 @@ def main():
     init_model(config, eval_prog, exe)
 
     blobs = reader_main(config, 'test')()
-    infer_img = config['TestReader']['infer_img']
+    infer_img = config['Global']['infer_img']
     infer_list = get_image_file_list(infer_img)
     max_img_num = len(infer_list)
     if len(infer_list) == 0:
         logger.info("Can not find img in infer_img dir.")
     for i in range(max_img_num):
-        print("infer_img:",infer_list[i])
+        print("infer_img:%s" % infer_list[i])
         img = next(blobs)
         predict = exe.run(program=eval_prog,
                           feed={"image": img},
                           fetch_list=fetch_varname_list,
                           return_numpy=False)
-        preds = np.array(predict[0])
-        if preds.shape[1] == 1:
+        if loss_type == "ctc":
+            preds = np.array(predict[0])
             preds = preds.reshape(-1)
             preds_lod = predict[0].lod()[0]
             preds_text = char_ops.decode(preds)
-        else:
+            probs = np.array(predict[1])
+            ind = np.argmax(probs, axis=1)
+            blank = probs.shape[1]
+            valid_ind = np.where(ind != (blank - 1))[0]
+            score = np.mean(probs[valid_ind, ind[valid_ind]])
+        elif loss_type == "attention":
+            preds = np.array(predict[0])
+            probs = np.array(predict[1])
             end_pos = np.where(preds[0, :] == 1)[0]
             if len(end_pos) <= 1:
-                preds_text = preds[0, 1:]
+                preds = preds[0, 1:]
+                score = np.mean(probs[0, 1:])
             else:
-                preds_text = preds[0, 1:end_pos[1]]
-            preds_text = preds_text.reshape(-1)
-            preds_text = char_ops.decode(preds_text)
-        print("\t index:",preds)
-        print("\t word :",preds_text)
+                preds = preds[0, 1:end_pos[1]]
+                score = np.mean(probs[0, 1:end_pos[1]])
+            preds = preds.reshape(-1)
+            preds_text = char_ops.decode(preds)
+        print("\t index:", preds)
+        print("\t word :", preds_text)
+        print("\t score :", score)
     # save for inference model
     target_var = []
...
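As context for the new scoring logic above: in the CTC branch, predict[1] is assumed to hold the per-timestep probability output, and the score is the mean of the best-class probabilities over timesteps whose argmax is not the blank label (taken to be the last class). A standalone sketch under those assumptions (the function name and toy data are illustrative, not part of the commit):

    import numpy as np

    def ctc_confidence(probs):
        """Mean max-probability over non-blank timesteps.

        probs: (T, C) per-timestep probability output; the last class
        index (C - 1) is assumed to be the CTC blank label.
        """
        ind = np.argmax(probs, axis=1)               # best class per timestep
        blank = probs.shape[1]                       # blank label id is blank - 1
        valid_ind = np.where(ind != (blank - 1))[0]  # timesteps that emit a character
        if len(valid_ind) == 0:
            return 0.0                               # every timestep predicted blank
        return float(np.mean(probs[valid_ind, ind[valid_ind]]))

    # toy usage: 5 timesteps, 4 classes (last one is the blank)
    toy = np.random.rand(5, 4)
    toy /= toy.sum(axis=1, keepdims=True)
    print(ctc_confidence(toy))

The attention branch works analogously: it averages probs[0, 1:...] over the decoded positions up to the first end token (label id 1), as shown in the diff.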
@@ -114,7 +114,7 @@ def merge_config(config):
             global_config[key] = value
         else:
             sub_keys = key.split('.')
-            assert (sub_keys[0] in global_config)
+            assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0])
             cur = global_config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
                 assert (sub_key in cur)
...
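For context on the strengthened assert: merge_config walks a dotted override key (e.g. Global.infer_img passed via -o/--opt, i.e. FLAGS.opt above) into the nested config, and the new message reports the available top-level sections when the first segment is unknown. A simplified standalone sketch of that merging logic, not the repo's exact function:

    # simplified global config with one top-level section
    global_config = {'Global': {'infer_img': None, 'loss_type': 'ctc'}}

    def apply_override(key, value):
        """Merge one 'A.B.C=value' style override into global_config."""
        if '.' not in key:
            global_config[key] = value
            return
        sub_keys = key.split('.')
        assert sub_keys[0] in global_config, (
            "the sub_keys can only be one of global_config: {}, but get: {}, "
            "please check your running command".format(
                list(global_config.keys()), sub_keys[0]))
        cur = global_config[sub_keys[0]]
        for idx, sub_key in enumerate(sub_keys[1:]):
            assert sub_key in cur
            if idx == len(sub_keys) - 2:
                cur[sub_key] = value      # last segment: assign the value
            else:
                cur = cur[sub_key]        # otherwise descend one level

    apply_override('Global.infer_img', './doc/imgs_words/word_1.jpg')
    print(global_config['Global']['infer_img'])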