Commit 404a3b31 authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents a5530565 d128c1df
......@@ -129,3 +129,9 @@ PaddleOCR mainly focuses on general OCR; for vertical-domain needs, you can use PaddleOCR +
A: It is normal for acc to be 0 in the early stage of recognition model training; the metric rises after training for a while.
***
Click the links below for detailed training tutorials:
- [Text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
- [Text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
- [Text direction classifier training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
\ No newline at end of file
# Configuration
- [1. Optional Parameter List](#1-optional-parameter-list)
- [2. Introduction to Global Parameters of Configuration File](#2-introduction-to-global-parameters-of-configuration-file)
......@@ -37,9 +37,8 @@ Take rec_chinese_lite_train_v2.0.yml as an example
| checkpoints | set model parameter path | None | Used to load parameters after interruption to continue training|
| use_visualdl | Set whether to enable visualdl for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
| infer_img | Set inference image path or folder path | ./infer_img | \ |
| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | \ |
| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If character_dict_path is None, the model can only recognize digits and lowercase letters |
| max_text_length | Set the maximum length of text | 25 | \ |
| character_type | Set character type | ch | en/ch, the default dict will be used for en, and the custom dict will be used for ch |
| use_space_char | Set whether to recognize spaces | True | Only supported when character_type=ch |
| label_list | Set the angle supported by the direction classifier | ['0','180'] | Only valid in angle classifier model |
| save_res_path | Set the save address of the test model results | ./output/det_db/predicts_db.txt | Only valid in the text detection model |
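To make the table concrete, here is a minimal Python sketch with a hypothetical helper name, mirroring the charset logic this commit introduces in the `BaseRecLabelEncode` changes later in this diff: when `character_dict_path` is unset only digits and lowercase letters are recognizable, and `use_space_char` appends a space.

```python
# Hypothetical helper illustrating how the character set is resolved.
def load_charset(character_dict_path=None, use_space_char=False):
    if character_dict_path is None:
        # default charset: digits and lowercase letters only
        chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
    else:
        # dict files store one character per line
        with open(character_dict_path, "rb") as fin:
            chars = [line.decode("utf-8").strip("\n").strip("\r\n")
                     for line in fin]
    if use_space_char:
        chars.append(" ")
    return chars
```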
......@@ -196,40 +195,39 @@ Italian is made up of Latin letters, so after executing the command, you will ge
use_gpu: True
epoch_num: 500
...
character_type: it # language
character_dict_path: {path/of/dict} # path of dict
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of training data
label_file_list: ["./train_data/train_list.txt"] # train label path
...
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of val data
label_file_list: ["./train_data/val_list.txt"] # val label path
...
```
Currently, the multi-language algorithms supported by PaddleOCR are:
| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) |
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to: [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
......
# Environment Preparation
Recommended working environment:
- PaddlePaddle >= 2.0.0 (2.1.2)
- python3.7
- CUDA10.1 / CUDA10.2
- CUDNN 7.6
* [1. Python Environment Setup](#1)
+ [1.1 Windows](#1.1)
+ [1.2 Mac](#1.2)
+ [1.3 Linux](#1.3)
* [2. Install PaddlePaddle 2.0](#2)
<a name="1"></a>
## 1. Python Environment Setup
......@@ -38,7 +45,7 @@
- Check the option to add conda to your environment variables, and ignore the warning that follows
<img src="../install/windows/anaconda_install_env.png" alt="add conda to path" width="500" align="center"/>
#### 1.1.2 Opening the terminal and creating the conda environment
......@@ -69,7 +76,7 @@
# View the current location of python
where python
```
<img src="../install/windows/conda_list_env.png" alt="create environment" width="600" align="center"/>
With the steps above, the Anaconda and Python environments are installed
......@@ -133,13 +140,13 @@ The above anaconda environment and python environment are installed
# !!! Contents within this block are managed by 'conda init' !!!
__conda_setup="$('/Users/xxx/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
eval "$__conda_setup"
else
if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
. "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
else
export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
fi
if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
. "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
else
export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
fi
fi
unset __conda_setup
# <<< conda initialize <<<
......@@ -197,11 +204,10 @@ Linux users can choose to run either Anaconda or Docker. If you are familiar wit
- **Download Anaconda**.
- Download at: https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D
<img src="../install/linux/anaconda_download.png" akt="anaconda download" width="800" align="center"/>
- Select the appropriate version for your operating system
- Type `uname -m` in the terminal to check your system's instruction set architecture
......@@ -216,12 +222,12 @@ Linux users can choose to run either Anaconda or Docker. If you are familiar wit
sudo yum install wget # CentOS
```
```bash
# Then use wget to download from Tsinghua source
# Then use wget to download from Tsinghua source
# If you want to download Anaconda3-2021.05-Linux-x86_64.sh, the download command is as follows
wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2021.05-Linux-x86_64.sh
# If you want to download another version, change the file name after the last / to the version you want to download
```
- To install Anaconda.
- Type `sh Anaconda3-2021.05-Linux-x86_64.sh` at the command line
......@@ -309,7 +315,18 @@ cd /home/Projects
# Create a docker container named ppocr and map the current directory to the /paddle directory of the container
# If using CPU, use docker instead of nvidia-docker to create the container
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
# If using GPU, use nvidia-docker to create the container
# docker image registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 is recommended for CUDA11.2 + CUDNN8.
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
```
You can also visit [DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get the image that fits your machine.
```
# Press Ctrl+P+Q to detach from the container; re-enter it with the following command:
sudo docker container exec -it ppocr /bin/bash
```
<a name="2"></a>
......@@ -329,4 +346,3 @@ python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
For more software version requirements, please refer to the instructions in the [Installation Document](https://www.paddlepaddle.org.cn/install/quick).
......@@ -21,7 +21,7 @@ Next, we first introduce how to convert a trained model into an inference model,
- [2.2 DB Text Detection Model Inference](#DB_DETECTION)
- [2.3 East Text Detection Model Inference](#EAST_DETECTION)
- [2.4 Sast Text Detection Model Inference](#SAST_DETECTION)
- [3. Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
- [3.1 Lightweight Chinese Text Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
- [3.2 CTC-Based Text Recognition Model Inference](#CTC-BASED_RECOGNITION)
......@@ -281,7 +281,7 @@ python3 tools/export_model.py -c configs/det/rec_r34_vd_none_bilstm_ctc.yml -o G
For CRNN text recognition model inference, execute the following commands:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
![](../imgs_words_en/word_336.png)
......@@ -314,7 +314,7 @@ with the training, such as: --rec_image_shape="1, 64, 256"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
--rec_char_type="en" \
--rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
......@@ -323,7 +323,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
If the text dictionary is modified during training, you need to specify the dictionary path used for prediction via `--rec_char_dict_path`
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
......@@ -333,7 +333,7 @@ If you need to predict other language models, when using inference model predict
You need to specify the visual font path through `--vis_font_path`. Fonts for minority languages are provided by default under the `doc/fonts` path, such as Korean recognition:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
......@@ -399,7 +399,7 @@ If you want to try other detection algorithms or recognition algorithms, please
The following command uses the combination of the EAST text detection and STAR-Net text recognition:
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
After executing the command, the recognition result image is as follows:
......
......@@ -161,7 +161,7 @@ The current multi-language model is still in the demo stage and will continue to
If you like, you can submit the dictionary file to [dict](../../ppocr/utils/dict) and we will acknowledge you in the repo.
To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and set `character_type` to `ch`.
To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml`.
- Custom dictionary
......@@ -172,8 +172,6 @@ If you need to customize dic file, please add character_dict_path field in confi
If you want to support the recognition of the `space` category, please set the `use_space_char` field in the yml file to `True`.
**Note: use_space_char only takes effect when character_type=ch**
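For illustration, here is a minimal hedged sketch of producing such a dict file in the expected one-character-per-line format, which the dict files under `ppocr/utils` (like the English dict added in this commit) follow; the output path and character set below are made up:

```python
# Hypothetical example: write a minimal custom dictionary file for
# character_dict_path, one character per line.
custom_chars = list("0123456789abcdefghijklmnopqrstuvwxyz-.")
with open("ppocr/utils/my_custom_dict.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(custom_chars))
```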
<a name="TRAINING"></a>
## 2.Training
......@@ -250,7 +248,6 @@ Global:
# To use a custom dictionary, or after modifying the dictionary, point this path to the new dictionary file
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
# Modify character type
character_type: ch
...
# Whether to recognize spaces
use_space_char: True
......@@ -312,18 +309,18 @@ Eval:
Currently, the multi-language algorithms supported by PaddleOCR are:
| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) |
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to: [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
......@@ -471,6 +468,3 @@ inference/det_db/
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
```
......@@ -147,3 +147,9 @@ There are several experiences for reference when constructing the data set:
A: It is normal for acc to be 0 at the beginning of recognition model training; the metric will rise after a longer training period.
***
Click the following links for detailed training tutorials:
- [text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
- [text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
- [text direction classification model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
......@@ -21,6 +21,8 @@ import numpy as np
import string
import json
from ppocr.utils.logging import get_logger
class ClsLabelEncode(object):
def __init__(self, label_list, **kwargs):
......@@ -92,31 +94,23 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False):
support_character_type = [
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.lower = False
if character_dict_path is None:
logger = get_logger()
logger.warning(
"The character_dict_path is None, model can only recognize number and lower letters"
)
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.lower = True
else:
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
......@@ -125,7 +119,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
......@@ -147,7 +140,7 @@ class BaseRecLabelEncode(object):
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
if self.character_type == "en":
if self.lower:
text = text.lower()
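# self.lower is True only when the built-in digits+lowercase charset is in use
# (character_dict_path is None), so labels are folded to lower case here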
text_list = []
for char in text:
......@@ -167,13 +160,11 @@ class NRTRLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN_symbol',
use_space_char=False,
**kwargs):
super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(NRTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
......@@ -200,12 +191,10 @@ class CTCLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(CTCLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(CTCLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
......@@ -231,12 +220,10 @@ class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncodeTest,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(E2ELabelEncodeTest, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
import json
......@@ -305,12 +292,10 @@ class AttnLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(AttnLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(AttnLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
......@@ -353,12 +338,10 @@ class SEEDLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SEEDLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(SEEDLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.end_str = "eos"
......@@ -385,12 +368,10 @@ class SRNLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(SRNLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
......@@ -598,12 +579,10 @@ class SARLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SARLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
super(SARLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
......
......@@ -87,17 +87,17 @@ class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_type='ch',
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_type = character_type
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_type == "ch":
if self.infer_mode and self.character_dict_path is not None:
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape, self.padding)
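# with a custom dict, the Chinese-style resize adapts width to the image's
# aspect ratio; otherwise the image is resized/padded to the fixed image_shape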
......
......@@ -32,6 +32,7 @@ class ACELoss(nn.Layer):
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
......@@ -42,9 +43,7 @@ class ACELoss(nn.Layer):
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div)
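# slot 0 now holds the blank count (N - label length); dividing by N gives the
# normalized character frequencies that ACE compares against the predictions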
loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss}
......@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def __init__(self,
num_classes=6625,
feat_dim=96,
......@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer):
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center
shape=[self.num_classes, self.feat_dim]).astype("float64")
if init_center:
assert os.path.exists(
......@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer):
batch_size = feats_reshape.shape[0]
#calc feat * feat
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
#calc l2 distance between feats and centers
square_feat = paddle.sum(paddle.square(feats_reshape),
axis=1,
keepdim=True)
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
#dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0])
square_center = paddle.sum(paddle.square(self.centers),
axis=1,
keepdim=True)
square_center = paddle.expand(
square_center, [self.num_classes, batch_size]).astype("float64")
square_center = paddle.transpose(square_center, [1, 0])
#first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp
distmat = paddle.add(square_feat, square_center)
feat_dot_center = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * feat_dot_center
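# the three terms above expand ||f - c||^2 = ||f||^2 + ||c||^2 - 2*(f.c)
# for every (sample, center) pair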
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
......@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer):
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64") #get mask
label).astype("float64")
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,26 +16,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
import math
import numpy as np
import paddle
from paddle import ParamAttr, reshape, transpose, concat, split
from paddle import ParamAttr, reshape, transpose
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
import math
from paddle.nn.functional import hardswish, hardsigmoid
from paddle.regularizer import L2Decay
from paddle.nn.functional import hardswish, hardsigmoid
class ConvBNLayer(nn.Layer):
......
......@@ -21,33 +21,15 @@ import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False):
support_character_type = [
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
......@@ -57,9 +39,6 @@ class BaseRecLabelDecode(object):
self.character_str.append(" ")
dict_character = list(self.character_str)
else:
raise NotImplementedError
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
......@@ -102,13 +81,10 @@ class BaseRecLabelDecode(object):
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
......@@ -136,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
......@@ -162,13 +137,9 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='EN_symbol',
use_space_char=True,
**kwargs):
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
......@@ -230,13 +201,10 @@ class NRTRLabelDecode(BaseRecLabelDecode):
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
......@@ -313,13 +281,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SEEDLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
......@@ -394,13 +359,10 @@ class SEEDLabelDecode(BaseRecLabelDecode):
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='en',
use_space_char=False,
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
......@@ -616,13 +578,10 @@ class TableLabelDecode(object):
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
......
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
\ No newline at end of file
# Introduction to the Training-to-Inference Deployment Toolchain Test Method
test.sh is used together with the params.txt files to test the full pipeline of the lightweight OCR detection and recognition models, from training to prediction.
# Installing Dependencies
- Install PaddlePaddle >= 2.0
- Install PaddleOCR dependencies
```
pip3 install -r ../requirements.txt
```
- Install autolog
```
git clone https://github.com/LDOUBLEV/AutoLog
cd AutoLog
pip3 install -r requirements.txt
python3 setup.py bdist_wheel
pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
cd ../
```
# Directory Structure
```bash
tests/
├── ocr_det_params.txt # parameter config file for testing the OCR detection model
├── ocr_rec_params.txt # parameter config file for testing the OCR recognition model
├── ocr_ppocr_mobile_params.txt # parameter config file for testing the cascaded OCR detection + recognition models
└── prepare.sh # downloads the data and models required to run test.sh
└── test.sh # main test program
```
# Usage
test.sh supports five run modes; each mode runs on different data and is used to test speed and accuracy respectively. The modes are:
- Mode 1: lite_train_infer, which trains with a small amount of data to quickly verify that the training-to-prediction pipeline runs end to end; accuracy and speed are not checked;
```shell
bash tests/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer'
```
- Mode 2: whole_infer, which trains with a small amount of data and predicts on a moderate amount, verifying that the trained model can run prediction and that the prediction speed is reasonable;
```shell
bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer'
```
- Mode 3: infer, which skips training and predicts on the full dataset, exercising open-source model evaluation and dynamic-to-static conversion and checking the inference model's prediction time and accuracy;
```shell
bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer'
# Usage 1:
bash tests/test.sh ./tests/ocr_det_params.txt 'infer'
# Usage 2: predict on a specified GPU card; the third argument is the GPU card id
bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1'
```
- Mode 4: whole_train_infer (CE), which trains and predicts on the full dataset, verifying model training accuracy, prediction accuracy, and prediction speed;
```shell
bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer'
```
- Mode 5: cpp_infer (CE), which verifies that C++ prediction with the inference model runs end to end;
```shell
bash tests/prepare.sh ./tests/ocr_det_params.txt 'cpp_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'cpp_infer'
```
# Log Output
Log files with a .log suffix are finally generated under the `tests/output` directory.
#!/bin/bash
source tests/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
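# read the first 51 lines of the params file; each ${lines[i]} lookup below parses one of them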
# parser params
IFS=$'\n'
lines=(${dataline})
# The training params
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
gpu_list=$(func_parser_value "${lines[3]}")
train_use_gpu_key=$(func_parser_key "${lines[4]}")
train_use_gpu_value=$(func_parser_value "${lines[4]}")
autocast_list=$(func_parser_value "${lines[5]}")
autocast_key=$(func_parser_key "${lines[5]}")
epoch_key=$(func_parser_key "${lines[6]}")
epoch_num=$(func_parser_params "${lines[6]}")
save_model_key=$(func_parser_key "${lines[7]}")
train_batch_key=$(func_parser_key "${lines[8]}")
train_batch_value=$(func_parser_params "${lines[8]}")
pretrain_model_key=$(func_parser_key "${lines[9]}")
pretrain_model_value=$(func_parser_value "${lines[9]}")
train_model_name=$(func_parser_value "${lines[10]}")
train_infer_img_dir=$(func_parser_value "${lines[11]}")
train_param_key1=$(func_parser_key "${lines[12]}")
train_param_value1=$(func_parser_value "${lines[12]}")
trainer_list=$(func_parser_value "${lines[14]}")
trainer_norm=$(func_parser_key "${lines[15]}")
norm_trainer=$(func_parser_value "${lines[15]}")
pact_key=$(func_parser_key "${lines[16]}")
pact_trainer=$(func_parser_value "${lines[16]}")
fpgm_key=$(func_parser_key "${lines[17]}")
fpgm_trainer=$(func_parser_value "${lines[17]}")
distill_key=$(func_parser_key "${lines[18]}")
distill_trainer=$(func_parser_value "${lines[18]}")
trainer_key1=$(func_parser_key "${lines[19]}")
trainer_value1=$(func_parser_value "${lines[19]}")
trainer_key2=$(func_parser_key "${lines[20]}")
trainer_value2=$(func_parser_value "${lines[20]}")
eval_py=$(func_parser_value "${lines[23]}")
eval_key1=$(func_parser_key "${lines[24]}")
eval_value1=$(func_parser_value "${lines[24]}")
save_infer_key=$(func_parser_key "${lines[27]}")
export_weight=$(func_parser_key "${lines[28]}")
norm_export=$(func_parser_value "${lines[29]}")
pact_export=$(func_parser_value "${lines[30]}")
fpgm_export=$(func_parser_value "${lines[31]}")
distill_export=$(func_parser_value "${lines[32]}")
export_key1=$(func_parser_key "${lines[33]}")
export_value1=$(func_parser_value "${lines[33]}")
export_key2=$(func_parser_key "${lines[34]}")
export_value2=$(func_parser_value "${lines[34]}")
# parser inference model
infer_model_dir_list=$(func_parser_value "${lines[36]}")
infer_export_list=$(func_parser_value "${lines[37]}")
infer_is_quant=$(func_parser_value "${lines[38]}")
# parser inference
inference_py=$(func_parser_value "${lines[39]}")
use_gpu_key=$(func_parser_key "${lines[40]}")
use_gpu_list=$(func_parser_value "${lines[40]}")
use_mkldnn_key=$(func_parser_key "${lines[41]}")
use_mkldnn_list=$(func_parser_value "${lines[41]}")
cpu_threads_key=$(func_parser_key "${lines[42]}")
cpu_threads_list=$(func_parser_value "${lines[42]}")
batch_size_key=$(func_parser_key "${lines[43]}")
batch_size_list=$(func_parser_value "${lines[43]}")
use_trt_key=$(func_parser_key "${lines[44]}")
use_trt_list=$(func_parser_value "${lines[44]}")
precision_key=$(func_parser_key "${lines[45]}")
precision_list=$(func_parser_value "${lines[45]}")
infer_model_key=$(func_parser_key "${lines[46]}")
image_dir_key=$(func_parser_key "${lines[47]}")
infer_img_dir=$(func_parser_value "${lines[47]}")
save_log_key=$(func_parser_key "${lines[48]}")
benchmark_key=$(func_parser_key "${lines[49]}")
benchmark_value=$(func_parser_value "${lines[49]}")
infer_key1=$(func_parser_key "${lines[50]}")
infer_value1=$(func_parser_value "${lines[50]}")
LOG_PATH="./tests/output"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_python.log"
function func_inference(){
IFS='|'
_python=$1
_script=$2
_model_dir=$3
_log_path=$4
_img_dir=$5
_flag_quant=$6
# inference
for use_gpu in ${use_gpu_list[*]}; do
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
for use_mkldnn in ${use_mkldnn_list[*]}; do
if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
continue
fi
for threads in ${cpu_threads_list[*]}; do
for batch_size in ${batch_size_list[*]}; do
_save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}"
done
done
done
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
for use_trt in ${use_trt_list[*]}; do
for precision in ${precision_list[*]}; do
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
continue
fi
for batch_size in ${batch_size_list[*]}; do
_save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
set_precision=$(func_set_params "${precision_key}" "${precision}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}"
done
done
done
else
echo "Does not support hardware other than CPU and GPU Currently!"
fi
done
}
# set cuda device
GPUID=$2
if [ ${#GPUID} -le 0 ];then
env=" "
else
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
fi
set CUDA_VISIBLE_DEVICES
eval $env
echo "################### run test ###################"
......@@ -131,14 +131,9 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
except:
except Exception as E:
logger.info(traceback.format_exc())
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
......
......@@ -38,40 +38,34 @@ logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
......
......@@ -74,7 +74,6 @@ def init_args():
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
......@@ -268,10 +267,11 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if args.precision == "fp16":
config.enable_mkldnn_bfloat16()
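# on CPU, an fp16 precision request is served by oneDNN bfloat16 kernels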
# enable memory optim
config.enable_memory_optim()
#config.disable_glog_info()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'table':
......