Unverified Commit d354020d authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge branch 'dygraph' into fix_doc

parents 683cb556 76946e83
...@@ -180,7 +180,7 @@ class TextDetector(object): ...@@ -180,7 +180,7 @@ class TextDetector(object):
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
...@@ -237,7 +237,7 @@ class TextRecognizer(object): ...@@ -237,7 +237,7 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
self.predictor.try_shrink_memory()
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger): ...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger):
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# config.enable_memory_optim() # enable memory optim
config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
...@@ -237,8 +237,9 @@ def train(config, ...@@ -237,8 +237,9 @@ def train(config,
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
vdl_writer.add_scalar('TRAIN/lr', lr, global_step) vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
if dist.get_rank( if dist.get_rank() == 0 and (
) == 0 and global_step > 0 and global_step % print_batch_step == 0: (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost / epoch, epoch_num, global_step, logs, train_reader_cost /
......
...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer): ...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer):
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0: if len(train_dataloader) == 0:
logger.error( logger.error(
'No Images in train dataset, please check annotation file and path in the configuration file' "No Images in train dataset, please ensure\n" +
"\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+
"\t2. The annotation file and path in the configuration file are provided normally."
) )
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment