Commit c94428a8 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix conflicts

parents 390f240b 7efa3975
......@@ -9,9 +9,9 @@
### 1.文本检测算法
PaddleOCR开源的文本检测算法列表:
- [x] DB([paper]( https://arxiv.org/abs/1911.08947) )(ppocr推荐)
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
- [x] DB([paper]( https://arxiv.org/abs/1911.08947)) [2](ppocr推荐)
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[1]
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
在ICDAR2015文本检测公开数据集上,算法效果如下:
......@@ -38,13 +38,13 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训
### 2.文本识别算法
PaddleOCR基于动态图开源的文本识别算法列表:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717) )(ppocr推荐)
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) coming soon
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) coming soon
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐)
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] coming soon
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|模型|骨干网络|Avg Accuracy|模型存储命名|下载链接|
|-|-|-|-|-|
......
......@@ -117,7 +117,7 @@ python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/
```
# 预测分类结果
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
预测图片:
......
......@@ -120,16 +120,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
测试单张图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
测试DB模型时,调整后处理阈值,
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
测试文件夹下所有图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
......@@ -245,7 +245,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
超轻量中文识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
# 下载超轻量中文识别模型:
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="ch_ppocr_mobile_v2.0_rec_infer"
```
![](../imgs_words/ch/word_4.jpg)
......@@ -266,7 +269,6 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)
```
python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/rec_crnn
```
CRNN 文本识别模型推理,可以执行如下命令:
......@@ -327,7 +329,10 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
方向分类模型推理,可以执行如下命令:
```
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
# 下载超轻量中文方向分类器模型:
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="ch_ppocr_mobile_v2.0_cls_infer"
```
![](../imgs_words/ch/word_1.jpg)
......
......@@ -324,7 +324,6 @@ Eval:
评估数据集可以通过 `configs/rec/rec_icdar15_train.yml` 修改Eval中的 `label_file_path` 设置。
*注意* 评估时必须确保配置文件中 infer_img 字段为空
```
# GPU 评估, Global.checkpoints 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
......@@ -342,7 +341,7 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
```
# 预测英文结果
python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/en/word_1.png
```
预测图片:
......@@ -361,7 +360,7 @@ infer_img: doc/imgs_words/en/word_1.png
```
# 预测中文结果
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
预测图片:
......
......@@ -11,11 +11,12 @@
}
2. DB:
@article{liao2019real,
title={Real-time Scene Text Detection with Differentiable Binarization},
@inproceedings{liao2020real,
title={Real-Time Scene Text Detection with Differentiable Binarization.},
author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
journal={arXiv preprint arXiv:1911.08947},
year={2019}
booktitle={AAAI},
pages={11474--11481},
year={2020}
}
3. DTRB:
......@@ -37,10 +38,11 @@
}
5. SRN:
@article{yu2020towards,
title={Towards Accurate Scene Text Recognition with Semantic Reasoning Networks},
author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Han, Junyu and Liu, Jingtuo and Ding, Errui},
journal={arXiv preprint arXiv:2003.12294},
@inproceedings{yu2020towards,
title={Towards accurate scene text recognition with semantic reasoning networks},
author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Liu, Tao and Han, Junyu and Liu, Jingtuo and Ding, Errui},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={12113--12122},
year={2020}
}
......@@ -52,4 +54,62 @@
pages={9086--9095},
year={2019}
}
7. CRNN:
@article{shi2016end,
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
journal={IEEE transactions on pattern analysis and machine intelligence},
volume={39},
number={11},
pages={2298--2304},
year={2016},
publisher={IEEE}
}
8. FPGM:
@inproceedings{he2019filter,
title={Filter pruning via geometric median for deep convolutional neural networks acceleration},
author={He, Yang and Liu, Ping and Wang, Ziwei and Hu, Zhilan and Yang, Yi},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4340--4349},
year={2019}
}
9. PACT:
@article{choi2018pact,
title={Pact: Parameterized clipping activation for quantized neural networks},
author={Choi, Jungwook and Wang, Zhuo and Venkataramani, Swagath and Chuang, Pierce I-Jen and Srinivasan, Vijayalakshmi and Gopalakrishnan, Kailash},
journal={arXiv preprint arXiv:1805.06085},
year={2018}
}
10.Rosetta
@inproceedings{borisyuk2018rosetta,
title={Rosetta: Large scale system for text detection and recognition in images},
author={Borisyuk, Fedor and Gordo, Albert and Sivakumar, Viswanath},
booktitle={Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
pages={71--79},
year={2018}
}
11.STAR-Net
@inproceedings{liu2016star,
title={STAR-Net: A SpaTial Attention Residue Network for Scene Text Recognition.},
author={Liu, Wei and Chen, Chaofeng and Wong, Kwan-Yee K and Su, Zhizhong and Han, Junyu},
booktitle={BMVC},
volume={2},
pages={7},
year={2016}
}
12.RARE
@inproceedings{shi2016robust,
title={Robust scene text recognition with automatic rectification},
author={Shi, Baoguang and Wang, Xinggang and Lyu, Pengyuan and Yao, Cong and Bai, Xiang},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={4168--4176},
year={2016}
}
```
......@@ -11,9 +11,9 @@ This tutorial lists the text detection algorithms and text recognition algorithm
### 1. Text Detection Algorithm
PaddleOCR open source text detection algorithms list:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498) )(Baidu Self-Research)
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[2]
- [x] DB([paper](https://arxiv.org/abs/1911.08947))[1]
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
On the ICDAR2015 dataset, the text detection result is as follows:
......@@ -39,11 +39,11 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
### 2. Text Recognition Algorithm
PaddleOCR open-source text recognition algorithms list:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) coming soon
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294) )(Baidu Self-Research) coming soon
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] coming soon
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
......
......@@ -119,7 +119,7 @@ Use `Global.infer_img` to specify the path of the predicted picture or folder, a
```
# Predict English results
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_10.png
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words_en/word_10.png
```
Input image:
......
......@@ -113,16 +113,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
Test the detection result on a single image:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
When testing the DB model, adjust the post-processing threshold:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
Test the detection result on all images in the folder:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
......@@ -255,15 +255,18 @@ The following will introduce the lightweight Chinese recognition model inference
For lightweight Chinese recognition model inference, you can execute the following commands:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
# download CRNN text recognition inference model
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_10.png" --rec_model_dir="ch_ppocr_mobile_v2.0_rec_infer"
```
![](../imgs_words/ch/word_4.jpg)
![](../imgs_words_en/word_10.png)
After executing the command, the prediction results (recognized text and score) of the above image will be printed on the screen.
```bash
Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)
Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.9897658)
```
<a name="CTC-BASED_RECOGNITION"></a>
......@@ -339,7 +342,12 @@ For angle classification model inference, you can execute the following commands
```
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words_en/word_10.png" --cls_model_dir="./inference/cls/"
```
```
# download text angle class inference model:
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words_en/word_10.png" --cls_model_dir="ch_ppocr_mobile_v2.0_cls_infer"
```
![](../imgs_words_en/word_10.png)
After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen.
......
......@@ -317,11 +317,11 @@ Eval:
<a name="EVALUATION"></a>
### EVALUATION
The evaluation data set can be modified via `configs/rec/rec_icdar15_reader.yml` setting of `label_file_path` in EvalReader.
The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/rec/rec_icdar15_train.yml` file.
```
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_reader.yml -o Global.checkpoints={path/to/weights}/best_accuracy
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
<a name="PREDICTION"></a>
......@@ -336,7 +336,7 @@ The default prediction picture is stored in `infer_img`, and the weight is speci
```
# Predict English results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/en/word_1.jpg
```
Input image:
......@@ -354,7 +354,7 @@ The configuration file used for prediction must be consistent with the training.
```
# Predict Chinese results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/ch/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
Input image:
......
doc/joinus.PNG

272 KB | W: | H:

doc/joinus.PNG

212 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
......@@ -262,8 +262,8 @@ class PaddleOCR(predict_system.TextSystem):
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
postprocess_params.rec_char_dict_path = Path(
__file__).parent / postprocess_params.rec_char_dict_path
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
# init det_model and rec_model
super().__init__(postprocess_params)
......
......@@ -45,7 +45,6 @@ class BalanceLoss(nn.Layer):
self.balance_loss = balance_loss
self.main_loss_type = main_loss_type
self.negative_ratio = negative_ratio
self.main_loss_type = main_loss_type
self.return_origin = return_origin
self.eps = eps
......
......@@ -102,7 +102,6 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
best_model_dict['start_epoch'] = best_model_dict['best_epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
......
......@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='2.0.1',
version='2.0.2',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
......@@ -71,6 +71,9 @@ class TextDetector(object):
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST":
pre_process_list[0] = {
'DetResizeForTest': {'resize_long': args.det_limit_side_len}
}
postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
......
......@@ -34,7 +34,6 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
......
......@@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
return metirc
def preprocess():
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
......@@ -350,15 +350,17 @@ def preprocess():
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
if is_train:
# save_config
save_model_dir = config['Global']['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(
name='root', log_file='{}/train.log'.format(save_model_dir))
yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir)
else:
log_file = None
logger = get_logger(name='root', log_file=log_file)
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
......
......@@ -110,6 +110,6 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment