Unverified Commit 4cac91eb authored by dyning's avatar dyning Committed by GitHub
Browse files

Merge pull request #132 from tink2123/add_rec_score

Add rec score
parents ddefd24d 9393a1b3
...@@ -41,13 +41,18 @@ class LMDBReader(object): ...@@ -41,13 +41,18 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.drop_last = False
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == "eval": self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test": self.drop_last = False
self.batch_size = 1 self.infer_img = params['infer_img']
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -100,13 +105,18 @@ class LMDBReader(object): ...@@ -100,13 +105,18 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
...@@ -126,9 +136,13 @@ class LMDBReader(object): ...@@ -126,9 +136,13 @@ class LMDBReader(object):
if sample_info is None: if sample_info is None:
continue continue
img, label = sample_info img, label = sample_info
outs = process_image(img, self.image_shape, label, outs = process_image(
self.char_ops, self.loss_type, img=img,
self.max_text_length) image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None: if outs is None:
continue continue
yield outs yield outs
...@@ -136,6 +150,7 @@ class LMDBReader(object): ...@@ -136,6 +150,7 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets): if finish_read_num == len(lmdb_sets):
break break
self.close_lmdb_dataset(lmdb_sets) self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -143,10 +158,11 @@ class LMDBReader(object): ...@@ -143,10 +158,11 @@ class LMDBReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -165,26 +181,34 @@ class SimpleReader(object): ...@@ -165,26 +181,34 @@ class SimpleReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval': self.drop_last = True
self.batch_size = params['test_batch_size_per_card']
else: else:
self.batch_size = 1 self.batch_size = params['test_batch_size_per_card']
self.infer_img = params['infer_img'] self.drop_last = False
def __call__(self, process_id): def __call__(self, process_id):
if self.mode != 'train': if self.mode != 'train':
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
...@@ -192,7 +216,7 @@ class SimpleReader(object): ...@@ -192,7 +216,7 @@ class SimpleReader(object):
img_num = len(label_infor_list) img_num = len(label_infor_list)
img_id_list = list(range(img_num)) img_id_list = list(range(img_num))
random.shuffle(img_id_list) random.shuffle(img_id_list)
if sys.platform=="win32": if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows." print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.") "num_workers will be 1.")
self.num_workers = 1 self.num_workers = 1
...@@ -204,7 +228,7 @@ class SimpleReader(object): ...@@ -204,7 +228,7 @@ class SimpleReader(object):
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
continue continue
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1] label = substr[1]
...@@ -222,9 +246,10 @@ class SimpleReader(object): ...@@ -222,9 +246,10 @@ class SimpleReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): ...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
return padding_im return padding_im
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def get_img_data(value): def get_img_data(value):
"""get_img_data""" """get_img_data"""
if not value: if not value:
...@@ -66,8 +92,13 @@ def process_image(img, ...@@ -66,8 +92,13 @@ def process_image(img,
label=None, label=None,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None,
tps=None,
infer_mode=False):
if not infer_mode or char_ops.character_type == "en" or tps != None:
norm_img = resize_norm_img(img, image_shape) norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() char_num = char_ops.get_char_num()
......
...@@ -30,6 +30,8 @@ class RecModel(object): ...@@ -30,6 +30,8 @@ class RecModel(object):
global_params = params['Global'] global_params = params['Global']
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
...@@ -60,8 +62,8 @@ class RecModel(object): ...@@ -60,8 +62,8 @@ class RecModel(object):
def create_feed(self, mode): def create_feed(self, mode):
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
...@@ -86,6 +88,16 @@ class RecModel(object): ...@@ -86,6 +88,16 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set img_shape to be the same , it may affect the inference effect"
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader
...@@ -110,7 +122,11 @@ class RecModel(object): ...@@ -110,7 +122,11 @@ class RecModel(object):
return loader, outputs return loader, outputs
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
return loader, {'decoded_out': decoded_out} predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
...@@ -123,6 +123,8 @@ class AttentionPredict(object): ...@@ -123,6 +123,8 @@ class AttentionPredict(object):
full_ids = fluid.layers.fill_constant_batch_size_like( full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1) input=init_state, shape=[-1, 1], dtype='int64', value=1)
full_scores = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='float32', value=1)
cond = layers.less_than(x=counter, y=array_len) cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
...@@ -171,6 +173,9 @@ class AttentionPredict(object): ...@@ -171,6 +173,9 @@ class AttentionPredict(object):
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1) new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids) fluid.layers.assign(new_ids, full_ids)
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
fluid.layers.assign(new_scores, full_scores)
layers.increment(x=counter, value=1, in_place=True) layers.increment(x=counter, value=1, in_place=True)
# update the memories # update the memories
...@@ -184,7 +189,7 @@ class AttentionPredict(object): ...@@ -184,7 +189,7 @@ class AttentionPredict(object):
length_cond = layers.less_than(x=counter, y=array_len) length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices)) finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids return full_ids, full_scores
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
...@@ -223,10 +228,10 @@ class AttentionPredict(object): ...@@ -223,10 +228,10 @@ class AttentionPredict(object):
decoder_size, char_num) decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1) _, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out) decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out} predicts = {'predict':predict, 'decoded_out':decoded_out}
else: else:
ids = self.gru_attention_infer( ids, predict = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim, decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size) encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids} predicts = {'predict':predict, 'decoded_out':ids}
return predicts return predicts
...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_sample_num = 0 total_sample_num = 0
total_acc_num = 0 total_acc_num = 0
total_batch_num = 0 total_batch_num = 0
if mode == "test": if mode == "eval":
is_remove_duplicate = False is_remove_duplicate = False
else: else:
is_remove_duplicate = True is_remove_duplicate = True
...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): ...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
total_correct_number = 0 total_correct_number = 0
eval_data_acc_info = {} eval_data_acc_info = {}
for eval_data in eval_data_list: for eval_data in eval_data_list:
config['EvalReader']['lmdb_sets_dir'] = \ config['TestReader']['lmdb_sets_dir'] = \
eval_data_dir + "/" + eval_data eval_data_dir + "/" + eval_data
eval_reader = reader_main(config=config, mode="eval") eval_reader = reader_main(config=config, mode="test")
eval_info_dict['reader'] = eval_reader eval_info_dict['reader'] = eval_reader
metrics = eval_rec_run(exe, config, eval_info_dict, "eval") metrics = eval_rec_run(exe, config, eval_info_dict, "test")
total_evaluation_data_number += metrics['total_sample_num'] total_evaluation_data_number += metrics['total_sample_num']
total_correct_number += metrics['total_acc_num'] total_correct_number += metrics['total_acc_num']
eval_data_acc_info[eval_data] = metrics eval_data_acc_info[eval_data] = metrics
......
...@@ -32,10 +32,16 @@ class TextRecognizer(object): ...@@ -32,10 +32,16 @@ class TextRecognizer(object):
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
char_ops_params = {} char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path char_ops_params["character_dict_path"] = args.rec_char_dict_path
if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc' char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
...@@ -80,13 +86,14 @@ class TextRecognizer(object): ...@@ -80,13 +86,14 @@ class TextRecognizer(object):
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0] rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0] predict_lod = self.output_tensors[1].lod()[0]
elapse = time.time() - starttime elapse = time.time() - starttime
predict_time += elapse predict_time += elapse
starttime = time.time()
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno] beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1] end = rec_idx_lod[rno + 1]
...@@ -100,6 +107,22 @@ class TextRecognizer(object): ...@@ -100,6 +107,22 @@ class TextRecognizer(object):
valid_ind = np.where(ind != (blank - 1))[0] valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score]) rec_res.append([preds_text, score])
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score])
return rec_res, predict_time return rec_res, predict_time
...@@ -116,7 +139,17 @@ if __name__ == "__main__": ...@@ -116,7 +139,17 @@ if __name__ == "__main__":
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import multiprocessing import multiprocessing
import numpy as np import numpy as np
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
if os.environ.get(key, None) is None: if os.environ.get(key, None) is None:
...@@ -54,6 +55,7 @@ def main(): ...@@ -54,6 +55,7 @@ def main():
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
logger.info(config) logger.info(config)
char_ops = CharacterOps(config['Global']) char_ops = CharacterOps(config['Global'])
loss_type = config['Global']['loss_type']
config['Global']['char_ops'] = char_ops config['Global']['char_ops'] = char_ops
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
...@@ -78,35 +80,44 @@ def main(): ...@@ -78,35 +80,44 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')() blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img'] infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img) infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list) max_img_num = len(infer_list)
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num): for i in range(max_img_num):
print("infer_img:",infer_list[i]) print("infer_img:%s" % infer_list[i])
img = next(blobs) img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
fetch_list=fetch_varname_list, fetch_list=fetch_varname_list,
return_numpy=False) return_numpy=False)
if loss_type == "ctc":
preds = np.array(predict[0]) preds = np.array(predict[0])
if preds.shape[1] == 1:
preds = preds.reshape(-1) preds = preds.reshape(-1)
preds_lod = predict[0].lod()[0] preds_lod = predict[0].lod()[0]
preds_text = char_ops.decode(preds) preds_text = char_ops.decode(preds)
else: probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention":
preds = np.array(predict[0])
probs = np.array(predict[1])
end_pos = np.where(preds[0, :] == 1)[0] end_pos = np.where(preds[0, :] == 1)[0]
if len(end_pos) <= 1: if len(end_pos) <= 1:
preds_text = preds[0, 1:] preds = preds[0, 1:]
score = np.mean(probs[0, 1:])
else: else:
preds_text = preds[0, 1:end_pos[1]] preds = preds[0, 1:end_pos[1]]
preds_text = preds_text.reshape(-1) score = np.mean(probs[0, 1:end_pos[1]])
preds_text = char_ops.decode(preds_text) preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
print("\t index:",preds) print("\t index:", preds)
print("\t word :",preds_text) print("\t word :", preds_text)
print("\t score :", score)
# save for inference model # save for inference model
target_var = [] target_var = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment