Unverified Commit 3deaca9d authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into update_whl

parents 806a5c8e d6ee6bdb
......@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
global_config = config['Global']
dataset_config = config[mode]['dataset']
......@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
......@@ -56,6 +57,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
......@@ -63,7 +65,6 @@ class SimpleDataSet(Dataset):
return data_lines
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
......@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
......
......@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length
hidden = paddle.zeros((batch_size, self.hidden_size))
......
......@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='2.0.2',
version='2.0.3',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
......@@ -98,10 +98,10 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime
for rno in range(len(cls_result)):
......
......@@ -180,7 +180,7 @@ class TextDetector(object):
preds['maps'] = outputs[0]
else:
raise NotImplementedError
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
......@@ -237,7 +237,7 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
self.predictor.try_shrink_memory()
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
......@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger):
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1
# config.enable_memory_optim()
# enable memory optim
config.enable_memory_optim()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
......@@ -237,8 +237,9 @@ def train(config,
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
if dist.get_rank(
) == 0 and global_step > 0 and global_step % print_batch_step == 0:
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost /
......
......@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer):
train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0:
logger.error(
'No Images in train dataset, please check annotation file and path in the configuration file'
"No Images in train dataset, please ensure\n" +
"\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+
"\t2. The annotation file and path in the configuration file are provided normally."
)
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment