Commit 0afe6c32 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'upstream/dygraph' into dy1

parents 453f9e5c 32665fe5
...@@ -36,7 +36,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -36,7 +36,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -114,7 +114,7 @@ for line in result: ...@@ -114,7 +114,7 @@ for line in result:
from PIL import Image from PIL import Image
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -253,7 +253,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -253,7 +253,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -285,7 +285,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -285,7 +285,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -314,7 +314,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -314,7 +314,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
......
...@@ -31,7 +31,9 @@ On Total-Text dataset, the text detection result is as follows: ...@@ -31,7 +31,9 @@ On Total-Text dataset, the text detection result is as follows:
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
|SAST|ResNet50_vd|89.63%|78.44%|83.66%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| |SAST|ResNet50_vd|89.63%|78.44%|83.66%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi). **Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from:
* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
* [Google Drive](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing)
For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./detection_en.md) For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./detection_en.md)
......
...@@ -35,7 +35,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -35,7 +35,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -116,7 +116,7 @@ for line in result: ...@@ -116,7 +116,7 @@ for line in result:
from PIL import Image from PIL import Image
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -262,7 +262,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -262,7 +262,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -292,7 +292,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -292,7 +292,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -320,7 +320,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -320,7 +320,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
......
...@@ -236,7 +236,9 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -236,7 +236,9 @@ class PaddleOCR(predict_system.TextSystem):
assert lang in model_urls[ assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format( 'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang) model_urls['rec'].keys(), lang)
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None: if postprocess_params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path'] 'dict_path']
...@@ -263,7 +265,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -263,7 +265,7 @@ class PaddleOCR(predict_system.TextSystem):
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0) sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str( postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path) Path(__file__).parent / postprocess_params.rec_char_dict_path)
...@@ -282,8 +284,13 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -282,8 +284,13 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
)
self.use_angle_cls = cls
if isinstance(img, str): if isinstance(img, str):
# download net image # download net image
if img.startswith('http'): if img.startswith('http'):
......
...@@ -32,7 +32,6 @@ class MakeShrinkMap(object): ...@@ -32,7 +32,6 @@ class MakeShrinkMap(object):
text_polys, ignore_tags = self.validate_polygons(text_polys, text_polys, ignore_tags = self.validate_polygons(text_polys,
ignore_tags, h, w) ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32) gt = np.zeros((h, w), dtype=np.float32)
# gt = np.zeros((1, h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32) mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)): for i in range(len(text_polys)):
polygon = text_polys[i] polygon = text_polys[i]
......
...@@ -117,13 +117,16 @@ class RawRandAugment(object): ...@@ -117,13 +117,16 @@ class RawRandAugment(object):
class RandAugment(RawRandAugment): class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """ """ RandAugment wrapper to auto fit different img types """
def __init__(self, *args, **kwargs): def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
if six.PY2: if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs) super(RandAugment, self).__init__(*args, **kwargs)
else: else:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __call__(self, data): def __call__(self, data):
if np.random.rand() > self.prob:
return data
img = data['image'] img = data['image']
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
......
...@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset): ...@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
self.mode = mode.lower()
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
...@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset): ...@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
...@@ -56,6 +57,7 @@ class SimpleDataSet(Dataset): ...@@ -56,6 +57,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed) random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
...@@ -63,7 +65,6 @@ class SimpleDataSet(Dataset): ...@@ -63,7 +65,6 @@ class SimpleDataSet(Dataset):
return data_lines return data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
...@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset): ...@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
data_line, e)) data_line, e))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) # during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):
......
...@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer): ...@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
return input_ont_hot return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25): def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0] batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length num_steps = batch_max_length
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
......
...@@ -32,7 +32,7 @@ setup( ...@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''}, package_dir={'paddleocr': ''},
include_package_data=True, include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='2.0.2', version='2.0.3',
install_requires=requirements, install_requires=requirements,
license='Apache License 2.0', license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
...@@ -98,10 +98,10 @@ class TextClassifier(object): ...@@ -98,10 +98,10 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run() self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
for rno in range(len(cls_result)): for rno in range(len(cls_result)):
......
...@@ -39,10 +39,7 @@ class TextDetector(object): ...@@ -39,10 +39,7 @@ class TextDetector(object):
self.args = args self.args = args
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': { 'DetResizeForTest': None
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
'std': [0.229, 0.224, 0.225], 'std': [0.229, 0.224, 0.225],
...@@ -183,7 +180,7 @@ class TextDetector(object): ...@@ -183,7 +180,7 @@ class TextDetector(object):
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
...@@ -237,7 +237,7 @@ class TextRecognizer(object): ...@@ -237,7 +237,7 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
self.predictor.try_shrink_memory()
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger): ...@@ -128,7 +128,8 @@ def create_predictor(args, mode, logger):
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# config.enable_memory_optim() # enable memory optim
config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
...@@ -97,7 +97,7 @@ def main(): ...@@ -97,7 +97,7 @@ def main():
preds = model(images) preds = model(images)
post_result = post_process_class(preds, shape_list) post_result = post_process_class(preds, shape_list)
boxes = post_result[0]['points'] boxes = post_result[0]['points']
# write resule # write result
dt_boxes_json = [] dt_boxes_json = []
for box in boxes: for box in boxes:
tmp_json = {"transcription": ""} tmp_json = {"transcription": ""}
......
...@@ -237,8 +237,9 @@ def train(config, ...@@ -237,8 +237,9 @@ def train(config,
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
vdl_writer.add_scalar('TRAIN/lr', lr, global_step) vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
if dist.get_rank( if dist.get_rank() == 0 and (
) == 0 and global_step > 0 and global_step % print_batch_step == 0: (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost / epoch, epoch_num, global_step, logs, train_reader_cost /
......
...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer): ...@@ -52,7 +52,10 @@ def main(config, device, logger, vdl_writer):
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0: if len(train_dataloader) == 0:
logger.error( logger.error(
'No Images in train dataset, please check annotation file and path in the configuration file' "No Images in train dataset, please ensure\n" +
"\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+
"\t2. The annotation file and path in the configuration file are provided normally."
) )
return return
...@@ -89,8 +92,10 @@ def main(config, device, logger, vdl_writer): ...@@ -89,8 +92,10 @@ def main(config, device, logger, vdl_writer):
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
format(len(train_dataloader), len(valid_dataloader))) if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
len(valid_dataloader)))
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment