Unverified Commit 3deaca9d authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into update_whl

parents 806a5c8e d6ee6bdb
...@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset): ...@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
self.mode = mode.lower()
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
...@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset): ...@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
...@@ -56,16 +57,16 @@ class SimpleDataSet(Dataset): ...@@ -56,16 +57,16 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
random.seed(self.seed) if self.mode == "train" or ratio_list[idx] < 1.0:
lines = random.sample(lines, random.seed(self.seed)
round(len(lines) * ratio_list[idx])) lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines) data_lines.extend(lines)
return data_lines return data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: random.seed(self.seed)
random.seed(self.seed) random.shuffle(self.data_lines)
random.shuffle(self.data_lines)
return return
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset): ...@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
data_line, e)) data_line, e))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) # during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):
......
...@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer): ...@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
return input_ont_hot return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25): def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0] batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length num_steps = batch_max_length
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
......
...@@ -32,7 +32,7 @@ setup( ...@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''}, package_dir={'paddleocr': ''},
include_package_data=True, include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='2.0.2', version='2.0.3',
install_requires=requirements, install_requires=requirements,
license='Apache License 2.0', license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
...@@ -98,10 +98,10 @@ class TextClassifier(object): ...@@ -98,10 +98,10 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run() self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
for rno in range(len(cls_result)): for rno in range(len(cls_result)):
......
...@@ -180,7 +180,7 @@ class TextDetector(object): ...@@ -180,7 +180,7 @@ class TextDetector(object):
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
...@@ -237,7 +237,7 @@ class TextRecognizer(object): ...@@ -237,7 +237,7 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
self.predictor.try_shrink_memory()
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger): ...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger):
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# config.enable_memory_optim() # enable memory optim
config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
...@@ -237,8 +237,9 @@ def train(config, ...@@ -237,8 +237,9 @@ def train(config,
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
vdl_writer.add_scalar('TRAIN/lr', lr, global_step) vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
if dist.get_rank( if dist.get_rank() == 0 and (
) == 0 and global_step > 0 and global_step % print_batch_step == 0: (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost / epoch, epoch_num, global_step, logs, train_reader_cost /
......
...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer): ...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer):
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0: if len(train_dataloader) == 0:
logger.error( logger.error(
'No Images in train dataset, please check annotation file and path in the configuration file' "No Images in train dataset, please ensure\n" +
"\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+
"\t2. The annotation file and path in the configuration file are provided normally."
) )
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment