Commit 55d71902 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 1d03fd33 10b7e706
...@@ -5,5 +5,6 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py ...@@ -5,5 +5,6 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data *.py recursive-include ppocr/data *.py
recursive-include ppocr/postprocess *.py recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py recursive-include tools/infer *.py
recursive-include tools __init__.py
recursive-include ppocr/utils/e2e_utils *.py recursive-include ppocr/utils/e2e_utils *.py
recursive-include ppstructure *.py recursive-include ppstructure *.py
\ No newline at end of file
...@@ -207,6 +207,24 @@ For some data that are difficult to recognize, the recognition results will not ...@@ -207,6 +207,24 @@ For some data that are difficult to recognize, the recognition results will not
pip install opencv-contrib-python-headless==4.2.0.32 pip install opencv-contrib-python-headless==4.2.0.32
``` ```
### Dataset division
- Enter the following command in the terminal to execute the dataset division script:
```
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
```
- Parameter Description:
trainValTestRatio is the division ratio of the number of images in the training set, validation set, and test set, set according to your actual situation, the default is 6:2:2
labelRootPath is the storage path of the dataset labeled by PPOCRLabel, the default is ../train_data/label
detRootPath is the path where the text detection dataset is divided according to the dataset marked by PPOCRLabel. The default is ../train_data/det
recRootPath is the path where the character recognition dataset is divided according to the dataset marked by PPOCRLabel. The default is ../train_data/rec
### Related ### Related
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg) 1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
\ No newline at end of file
...@@ -193,7 +193,23 @@ PPOCRLabel支持三种导出方式: ...@@ -193,7 +193,23 @@ PPOCRLabel支持三种导出方式:
``` ```
pip install opencv-contrib-python-headless==4.2.0.32 pip install opencv-contrib-python-headless==4.2.0.32
``` ```
### 数据集划分
- 在终端中输入以下命令执行数据集划分脚本:
```
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
```
- 参数说明:
trainValTestRatio是训练集、验证集、测试集的图像数量划分比例,根据你的实际情况设定,默认是6:2:2
labelRootPath是PPOCRLabel标注的数据集存放路径,默认是../train_data/label
detRootPath是根据PPOCRLabel标注的数据集划分后的文本检测数据集存放的路径,默认是../train_data/det
recRootPath是根据PPOCRLabel标注的数据集划分后的字符识别数据集存放的路径,默认是../train_data/rec
### 4. 参考资料 ### 4. 参考资料
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg) 1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
# coding:utf8
import os
import shutil
import random
import argparse
# 删除划分的训练集、验证集、测试集文件夹,重新创建一个空的文件夹
def isCreateOrDeleteFolder(path, flag):
flagPath = os.path.join(path, flag)
if os.path.exists(flagPath):
shutil.rmtree(flagPath)
os.makedirs(flagPath)
flagAbsPath = os.path.abspath(flagPath)
return flagAbsPath
def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
# 按照指定的比例划分训练集、验证集、测试集
labelPath = os.path.join(root, dir)
labelAbsPath = os.path.abspath(labelPath)
if flag == "det":
labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName)
elif flag == "rec":
labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName)
labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
labelFileContent = labelFileRead.readlines()
random.shuffle(labelFileContent)
labelRecordLen = len(labelFileContent)
for index, labelRecordInfo in enumerate(labelFileContent):
imageRelativePath = labelRecordInfo.split('\t')[0]
imageLabel = labelRecordInfo.split('\t')[1]
imageName = os.path.basename(imageRelativePath)
if flag == "det":
imagePath = os.path.join(labelAbsPath, imageName)
elif flag == "rec":
imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
# 按预设的比例划分训练集、验证集、测试集
trainValTestRatio = args.trainValTestRatio.split(":")
trainRatio = eval(trainValTestRatio[0]) / 10
valRatio = trainRatio + eval(trainValTestRatio[1]) / 10
curRatio = index / labelRecordLen
if curRatio < trainRatio:
imageCopyPath = os.path.join(absTrainRootPath, imageName)
shutil.copy(imagePath, imageCopyPath)
trainTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
elif curRatio >= trainRatio and curRatio < valRatio:
imageCopyPath = os.path.join(absValRootPath, imageName)
shutil.copy(imagePath, imageCopyPath)
valTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
else:
imageCopyPath = os.path.join(absTestRootPath, imageName)
shutil.copy(imagePath, imageCopyPath)
testTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
# 删掉存在的文件
def removeFile(path):
if os.path.exists(path):
os.remove(path)
def genDetRecTrainVal(args):
detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
detAbsTestRootPath = isCreateOrDeleteFolder(args.detRootPath, "test")
recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
recAbsTestRootPath = isCreateOrDeleteFolder(args.recRootPath, "test")
removeFile(os.path.join(args.detRootPath, "train.txt"))
removeFile(os.path.join(args.detRootPath, "val.txt"))
removeFile(os.path.join(args.detRootPath, "test.txt"))
removeFile(os.path.join(args.recRootPath, "train.txt"))
removeFile(os.path.join(args.recRootPath, "val.txt"))
removeFile(os.path.join(args.recRootPath, "test.txt"))
detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
for root, dirs, files in os.walk(args.labelRootPath):
for dir in dirs:
splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
break
if __name__ == "__main__":
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
# 如此会有多个标注好的图像文件夹汇总并划分训练集、验证集、测试集的需求
parser = argparse.ArgumentParser()
parser.add_argument(
"--trainValTestRatio",
type=str,
default="6:2:2",
help="ratio of trainset:valset:testset")
parser.add_argument(
"--labelRootPath",
type=str,
default="../train_data/label",
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
)
parser.add_argument(
"--detRootPath",
type=str,
default="../train_data/det",
help="the path where the divided detection dataset is placed")
parser.add_argument(
"--recRootPath",
type=str,
default="../train_data/rec",
help="the path where the divided recognition dataset is placed"
)
parser.add_argument(
"--detLabelFileName",
type=str,
default="Label.txt",
help="the name of the detection annotation file")
parser.add_argument(
"--recLabelFileName",
type=str,
default="rec_gt.txt",
help="the name of the recognition annotation file"
)
parser.add_argument(
"--recImageDirName",
type=str,
default="crop_img",
help="the name of the folder where the cropped recognition dataset is located"
)
args = parser.parse_args()
genDetRecTrainVal(args)
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [4000, 5000] eval_batch_step: [4000, 5000]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
...@@ -106,4 +106,4 @@ Eval: ...@@ -106,4 +106,4 @@ Eval:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size_per_card: 1 # must be 1 batch_size_per_card: 1 # must be 1
num_workers: 2 num_workers: 2
\ No newline at end of file
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [4000, 5000] eval_batch_step: [4000, 5000]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
......
Global: Global:
use_gpu: true use_gpu: true
epoch_num: 50 epoch_num: 400
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 5 save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/table/table.jpg
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 100
max_elem_length: 500 max_elem_length: 800
max_cell_num: 500 max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9 beta1: 0.9
...@@ -41,13 +40,15 @@ Architecture: ...@@ -41,13 +40,15 @@ Architecture:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 1.0 scale: 1.0
model_name: small model_name: large
disable_se: True
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001 l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
......
...@@ -307,21 +307,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img, ...@@ -307,21 +307,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
return filter_boxes; return filter_boxes;
} }
std::shared_ptr<PaddlePredictor> loadModel(std::string model_file, std::string power_mode, int num_threads) { std::shared_ptr<PaddlePredictor> loadModel(std::string model_file, int num_threads) {
MobileConfig config; MobileConfig config;
config.set_model_from_file(model_file); config.set_model_from_file(model_file);
if (power_mode == "LITE_POWER_HIGH"){
config.set_power_mode(LITE_POWER_HIGH);
} else {
if (power_mode == "LITE_POWER_LOW") {
config.set_power_mode(LITE_POWER_HIGH);
} else {
std::cerr << "Only support LITE_POWER_HIGH or LITE_POWER_HIGH." << std::endl;
exit(1);
}
}
config.set_threads(num_threads); config.set_threads(num_threads);
std::shared_ptr<PaddlePredictor> predictor = std::shared_ptr<PaddlePredictor> predictor =
...@@ -391,7 +380,7 @@ void check_params(int argc, char **argv) { ...@@ -391,7 +380,7 @@ void check_params(int argc, char **argv) {
if (strcmp(argv[1], "det") == 0) { if (strcmp(argv[1], "det") == 0) {
if (argc < 9){ if (argc < 9){
std::cerr << "[ERROR] usage:" << argv[0] std::cerr << "[ERROR] usage:" << argv[0]
<< " det det_model num_threads batchsize power_mode img_dir det_config lite_benchmark_value" << std::endl; << " det det_model runtime_device num_threads batchsize img_dir det_config lite_benchmark_value" << std::endl;
exit(1); exit(1);
} }
} }
...@@ -399,7 +388,7 @@ void check_params(int argc, char **argv) { ...@@ -399,7 +388,7 @@ void check_params(int argc, char **argv) {
if (strcmp(argv[1], "rec") == 0) { if (strcmp(argv[1], "rec") == 0) {
if (argc < 9){ if (argc < 9){
std::cerr << "[ERROR] usage:" << argv[0] std::cerr << "[ERROR] usage:" << argv[0]
<< " rec rec_model num_threads batchsize power_mode img_dir key_txt lite_benchmark_value" << std::endl; << " rec rec_model runtime_device num_threads batchsize img_dir key_txt lite_benchmark_value" << std::endl;
exit(1); exit(1);
} }
} }
...@@ -407,7 +396,7 @@ void check_params(int argc, char **argv) { ...@@ -407,7 +396,7 @@ void check_params(int argc, char **argv) {
if (strcmp(argv[1], "system") == 0) { if (strcmp(argv[1], "system") == 0) {
if (argc < 12){ if (argc < 12){
std::cerr << "[ERROR] usage:" << argv[0] std::cerr << "[ERROR] usage:" << argv[0]
<< " system det_model rec_model clas_model num_threads batchsize power_mode img_dir det_config key_txt lite_benchmark_value" << std::endl; << " system det_model rec_model clas_model runtime_device num_threads batchsize img_dir det_config key_txt lite_benchmark_value" << std::endl;
exit(1); exit(1);
} }
} }
...@@ -417,15 +406,15 @@ void system(char **argv){ ...@@ -417,15 +406,15 @@ void system(char **argv){
std::string det_model_file = argv[2]; std::string det_model_file = argv[2];
std::string rec_model_file = argv[3]; std::string rec_model_file = argv[3];
std::string cls_model_file = argv[4]; std::string cls_model_file = argv[4];
std::string precision = argv[5]; std::string runtime_device = argv[5];
std::string num_threads = argv[6]; std::string precision = argv[6];
std::string batchsize = argv[7]; std::string num_threads = argv[7];
std::string power_mode = argv[8]; std::string batchsize = argv[8];
std::string img_dir = argv[9]; std::string img_dir = argv[9];
std::string det_config_path = argv[10]; std::string det_config_path = argv[10];
std::string dict_path = argv[11]; std::string dict_path = argv[11];
if (strcmp(argv[5], "FP32") != 0 && strcmp(argv[5], "INT8") != 0) { if (strcmp(argv[6], "FP32") != 0 && strcmp(argv[6], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl; std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1); exit(1);
} }
...@@ -441,9 +430,9 @@ void system(char **argv){ ...@@ -441,9 +430,9 @@ void system(char **argv){
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" "); charactor_dict.push_back(" ");
auto det_predictor = loadModel(det_model_file, power_mode, std::stoi(num_threads)); auto det_predictor = loadModel(det_model_file, std::stoi(num_threads));
auto rec_predictor = loadModel(rec_model_file, power_mode, std::stoi(num_threads)); auto rec_predictor = loadModel(rec_model_file, std::stoi(num_threads));
auto cls_predictor = loadModel(cls_model_file, power_mode, std::stoi(num_threads)); auto cls_predictor = loadModel(cls_model_file, std::stoi(num_threads));
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::cout << "The predict img: " << cv_all_img_names[i] << std::endl; std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
...@@ -477,14 +466,14 @@ void system(char **argv){ ...@@ -477,14 +466,14 @@ void system(char **argv){
void det(int argc, char **argv) { void det(int argc, char **argv) {
std::string det_model_file = argv[2]; std::string det_model_file = argv[2];
std::string precision = argv[3]; std::string runtime_device = argv[3];
std::string num_threads = argv[4]; std::string precision = argv[4];
std::string batchsize = argv[5]; std::string num_threads = argv[5];
std::string power_mode = argv[6]; std::string batchsize = argv[6];
std::string img_dir = argv[7]; std::string img_dir = argv[7];
std::string det_config_path = argv[8]; std::string det_config_path = argv[8];
if (strcmp(argv[3], "FP32") != 0 && strcmp(argv[3], "INT8") != 0) { if (strcmp(argv[4], "FP32") != 0 && strcmp(argv[4], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl; std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1); exit(1);
} }
...@@ -495,7 +484,7 @@ void det(int argc, char **argv) { ...@@ -495,7 +484,7 @@ void det(int argc, char **argv) {
//// load config from txt file //// load config from txt file
auto Config = LoadConfigTxt(det_config_path); auto Config = LoadConfigTxt(det_config_path);
auto det_predictor = loadModel(det_model_file, power_mode, std::stoi(num_threads)); auto det_predictor = loadModel(det_model_file, std::stoi(num_threads));
std::vector<double> time_info = {0, 0, 0}; std::vector<double> time_info = {0, 0, 0};
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
...@@ -530,14 +519,11 @@ void det(int argc, char **argv) { ...@@ -530,14 +519,11 @@ void det(int argc, char **argv) {
if (strcmp(argv[9], "True") == 0) { if (strcmp(argv[9], "True") == 0) {
AutoLogger autolog(det_model_file, AutoLogger autolog(det_model_file,
0, runtime_device,
0,
0,
std::stoi(num_threads), std::stoi(num_threads),
std::stoi(batchsize), std::stoi(batchsize),
"dynamic", "dynamic",
precision, precision,
power_mode,
time_info, time_info,
cv_all_img_names.size()); cv_all_img_names.size());
autolog.report(); autolog.report();
...@@ -546,14 +532,14 @@ void det(int argc, char **argv) { ...@@ -546,14 +532,14 @@ void det(int argc, char **argv) {
void rec(int argc, char **argv) { void rec(int argc, char **argv) {
std::string rec_model_file = argv[2]; std::string rec_model_file = argv[2];
std::string precision = argv[3]; std::string runtime_device = argv[3];
std::string num_threads = argv[4]; std::string precision = argv[4];
std::string batchsize = argv[5]; std::string num_threads = argv[5];
std::string power_mode = argv[6]; std::string batchsize = argv[6];
std::string img_dir = argv[7]; std::string img_dir = argv[7];
std::string dict_path = argv[8]; std::string dict_path = argv[8];
if (strcmp(argv[3], "FP32") != 0 && strcmp(argv[3], "INT8") != 0) { if (strcmp(argv[4], "FP32") != 0 && strcmp(argv[4], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl; std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1); exit(1);
} }
...@@ -565,7 +551,7 @@ void rec(int argc, char **argv) { ...@@ -565,7 +551,7 @@ void rec(int argc, char **argv) {
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" "); charactor_dict.push_back(" ");
auto rec_predictor = loadModel(rec_model_file, power_mode, std::stoi(num_threads)); auto rec_predictor = loadModel(rec_model_file, std::stoi(num_threads));
std::shared_ptr<PaddlePredictor> cls_predictor; std::shared_ptr<PaddlePredictor> cls_predictor;
...@@ -603,14 +589,11 @@ void rec(int argc, char **argv) { ...@@ -603,14 +589,11 @@ void rec(int argc, char **argv) {
// TODO: support autolog // TODO: support autolog
if (strcmp(argv[9], "True") == 0) { if (strcmp(argv[9], "True") == 0) {
AutoLogger autolog(rec_model_file, AutoLogger autolog(rec_model_file,
0, runtime_device,
0,
0,
std::stoi(num_threads), std::stoi(num_threads),
std::stoi(batchsize), std::stoi(batchsize),
"dynamic", "dynamic",
precision, precision,
power_mode,
time_info, time_info,
cv_all_img_names.size()); cv_all_img_names.size());
autolog.report(); autolog.report();
......
...@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir): ...@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read() image_data = file.read()
image = cv2_to_base64(image_data) image = cv2_to_base64(image_data)
for i in range(1): for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"]) ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret) print(ret)
...@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model ...@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
...@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer): ...@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}") logger.info(f"FLOPs after pruning: {flops}")
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, None) load_model(config, model)
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class) eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}") logger.info(f"metric['hmean']: {metric['hmean']}")
......
...@@ -32,7 +32,7 @@ from ppocr.losses import build_loss ...@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer): ...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
......
...@@ -28,7 +28,7 @@ from paddle.jit import to_static ...@@ -28,7 +28,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
...@@ -101,7 +101,7 @@ def main(): ...@@ -101,7 +101,7 @@ def main():
quanter = QAT(config=quant_config) quanter = QAT(config=quant_config)
quanter.quantize(model) quanter.quantize(model)
init_model(config, model) load_model(config, model)
model.eval() model.eval()
# build metric # build metric
......
...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss ...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer): ...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
......
...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss ...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
import paddleslim import paddleslim
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
......
...@@ -101,7 +101,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \ ...@@ -101,7 +101,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID # 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
# 多机多卡训练,通过 --ips 参数设置使用的机器IP地址,通过 --gpus 参数设置使用的GPU ID # 多机多卡训练,通过 --ips 参数设置使用的机器IP地址,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
...@@ -109,14 +109,14 @@ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1 ...@@ -109,14 +109,14 @@ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md) 有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell ```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001 python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
``` ```
**注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。查看机器ip地址的命令为`ifconfig` **注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。另外,训练时需要在多个机器上分别启动命令。查看机器ip地址的命令为`ifconfig`
如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), 以单机单卡为例,命令如下: 如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), 以单机单卡为例,命令如下:
```shell ```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml \ python3 tools/train.py -c configs/det/det_mv3_db.yml \
......
...@@ -420,3 +420,5 @@ im_show.save('result.jpg') ...@@ -420,3 +420,5 @@ im_show.save('result.jpg')
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE | | cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
| show_log | 是否打印det和rec等信息 | FALSE | | show_log | 是否打印det和rec等信息 | FALSE |
| type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr | | type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr |
| ocr_version | OCR模型版本,可选PP-OCRv2, PP-OCR。PP-OCRv2 目前仅支持中文的检测和识别模型,PP-OCR支持中文的检测,识别,多语种识别,方向分类器等模型 | PP-OCRv2 |
| structure_version | 表格结构化模型版本,可选 STRUCTURE。STRUCTURE支持表格结构化模型 | STRUCTURE |
...@@ -98,14 +98,14 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o \ ...@@ -98,14 +98,14 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o \
# multi-GPU training # multi-GPU training
# Set the GPU ID used by the '--gpus' parameter. # Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
# multi-Node, multi-GPU training # multi-Node, multi-GPU training
# Set the IPs of your nodes used by the '--ips' parameter. Set the GPU ID used by the '--gpus' parameter. # Set the IPs of your nodes used by the '--ips' parameter. Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
``` ```
**Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. The command for viewing the IP address of the machine is `ifconfig`. **Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. In addition, it requires activating commands separately on multiple machines when we start the training. The command for viewing the IP address of the machine is `ifconfig`.
If you want to further speed up the training, you can use [automatic mixed precision training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_en.html). for single card training, the command is as follows: If you want to further speed up the training, you can use [automatic mixed precision training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_en.html). for single card training, the command is as follows:
``` ```
python3 tools/train.py -c configs/det/det_mv3_db.yml \ python3 tools/train.py -c configs/det/det_mv3_db.yml \
......
...@@ -366,4 +366,6 @@ im_show.save('result.jpg') ...@@ -366,4 +366,6 @@ im_show.save('result.jpg')
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
| show_log | Whether to print log in det and rec | FALSE | | show_log | Whether to print log in det and rec | FALSE |
| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr | | type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr |
\ No newline at end of file | ocr_version | OCR Model version number, the current model support list is as follows: PP-OCRv2 support Chinese detection and recognition model, PP-OCR support Chinese detection, recognition and direction classifier, multilingual recognition model | PP-OCRv2 |
| structure_version | table structure Model version number, the current model support list is as follows: STRUCTURE support english table structure model | STRUCTURE |
...@@ -16,6 +16,9 @@ import os ...@@ -16,6 +16,9 @@ import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(__file__)
import paddle
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
...@@ -29,7 +32,7 @@ from ppocr.utils.logging import get_logger ...@@ -29,7 +32,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool from tools.infer.utility import draw_ocr, str2bool, check_gpu
from ppstructure.utility import init_args, draw_structure_result from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res from ppstructure.predict_system import OCRSystem, save_structure_res
...@@ -39,130 +42,137 @@ __all__ = [ ...@@ -39,130 +42,137 @@ __all__ = [
] ]
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.2.1' VERSION = '2.3.0.1'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_MODEL_VERSION = '2.0' DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
MODEL_URLS = { MODEL_URLS = {
'2.1': { 'OCR': {
'det': { 'PP-OCRv2': {
'ch': { 'det': {
'url': 'ch': {
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar', 'url':
}, 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
}, },
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}
}
},
'2.0': {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
}, },
'structure': { 'rec': {
'url': 'ch': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar' 'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}
} }
}, },
'rec': { DEFAULT_OCR_MODEL_VERSION: {
'ch': { 'det': {
'url': 'ch': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'url':
'dict_path': './ppocr/utils/ppocr_keys_v1.txt' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
}, },
'en': { 'en': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt' },
}, 'structure': {
'french': { 'url':
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', }
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
}, },
'te': { 'rec': {
'url': 'ch': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'url':
'dict_path': './ppocr/utils/dict/te_dict.txt' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
}
}, },
'ka': { 'cls': {
'url': 'ch': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'url':
'dict_path': './ppocr/utils/dict/ka_dict.txt' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}
}, },
'latin': { }
'url': },
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'STRUCTURE': {
'dict_path': './ppocr/utils/dict/latin_dict.txt' DEFAULT_STRUCTURE_MODEL_VERSION: {
}, 'table': {
'arabic': { 'en': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt' 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}, }
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
}
},
'cls': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}
},
'table': {
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
} }
} }
} }
...@@ -177,7 +187,20 @@ def parse_args(mMain=True): ...@@ -177,7 +187,20 @@ def parse_args(mMain=True):
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--type", type=str, default='ocr') parser.add_argument("--type", type=str, default='ocr')
parser.add_argument("--version", type=str, default='2.1') parser.add_argument(
"--ocr_version",
type=str,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
'2. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.'
)
parser.add_argument(
"--structure_version",
type=str,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
for action in parser._actions: for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']: if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
...@@ -215,9 +238,9 @@ def parse_lang(lang): ...@@ -215,9 +238,9 @@ def parse_lang(lang):
lang = "cyrillic" lang = "cyrillic"
elif lang in devanagari_lang: elif lang in devanagari_lang:
lang = "devanagari" lang = "devanagari"
assert lang in MODEL_URLS[DEFAULT_MODEL_VERSION][ assert lang in MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION][
'rec'], 'param lang must in {}, but got {}'.format( 'rec'], 'param lang must in {}, but got {}'.format(
MODEL_URLS[DEFAULT_MODEL_VERSION]['rec'].keys(), lang) MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION]['rec'].keys(), lang)
if lang == "ch": if lang == "ch":
det_lang = "ch" det_lang = "ch"
elif lang == 'structure': elif lang == 'structure':
...@@ -227,33 +250,41 @@ def parse_lang(lang): ...@@ -227,33 +250,41 @@ def parse_lang(lang):
return lang, det_lang return lang, det_lang
def get_model_config(version, model_type, lang): def get_model_config(type, version, model_type, lang):
if version not in MODEL_URLS: if type == 'OCR':
logger.warning('version {} not in {}, use version {} instead'.format( DEFAULT_MODEL_VERSION = DEFAULT_OCR_MODEL_VERSION
version, MODEL_URLS.keys(), DEFAULT_MODEL_VERSION)) elif type == 'STRUCTURE':
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
model_urls = MODEL_URLS[type]
if version not in model_urls:
logger.warning('version {} not in {}, auto switch to version {}'.format(
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
if model_type not in MODEL_URLS[version]: if model_type not in model_urls[version]:
if model_type in MODEL_URLS[DEFAULT_MODEL_VERSION]: if model_type in model_urls[DEFAULT_MODEL_VERSION]:
logger.warning( logger.warning(
'version {} not support {} models, use version {} instead'. 'version {} not support {} models, auto switch to version {}'.
format(version, model_type, DEFAULT_MODEL_VERSION)) format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
else: else:
logger.error('{} models is not support, we only support {}'.format( logger.error('{} models is not support, we only support {}'.format(
model_type, MODEL_URLS[DEFAULT_MODEL_VERSION].keys())) model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1) sys.exit(-1)
if lang not in MODEL_URLS[version][model_type]: if lang not in model_urls[version][model_type]:
if lang in MODEL_URLS[DEFAULT_MODEL_VERSION][model_type]: if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
logger.warning('lang {} is not support in {}, use {} instead'. logger.warning(
format(lang, version, DEFAULT_MODEL_VERSION)) 'lang {} is not support in {}, auto switch to version {}'.
format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
else: else:
logger.error( logger.error(
'lang {} is not support, we only support {} for {} models'. 'lang {} is not support, we only support {} for {} models'.
format(lang, MODEL_URLS[DEFAULT_MODEL_VERSION][model_type].keys( format(lang, model_urls[DEFAULT_MODEL_VERSION][model_type].keys(
), model_type)) ), model_type))
sys.exit(-1) sys.exit(-1)
return MODEL_URLS[version][model_type][lang] return model_urls[version][model_type][lang]
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -265,23 +296,28 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -265,23 +296,28 @@ class PaddleOCR(predict_system.TextSystem):
""" """
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log: if not params.show_log:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls self.use_angle_cls = params.use_angle_cls
lang, det_lang = parse_lang(params.lang) lang, det_lang = parse_lang(params.lang)
# init model dir # init model dir
det_model_config = get_model_config(params.version, 'det', det_lang) det_model_config = get_model_config('OCR', params.ocr_version, 'det',
det_lang)
params.det_model_dir, det_url = confirm_model_dir_url( params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir, params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang), os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
det_model_config['url']) det_model_config['url'])
rec_model_config = get_model_config(params.version, 'rec', lang) rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
lang)
params.rec_model_dir, rec_url = confirm_model_dir_url( params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir, params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang), os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
rec_model_config['url']) rec_model_config['url'])
cls_model_config = get_model_config(params.version, 'cls', 'ch') cls_model_config = get_model_config('OCR', params.ocr_version, 'cls',
'ch')
params.cls_model_dir, cls_url = confirm_model_dir_url( params.cls_model_dir, cls_url = confirm_model_dir_url(
params.cls_model_dir, params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'), os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
...@@ -362,22 +398,27 @@ class PPStructure(OCRSystem): ...@@ -362,22 +398,27 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs): def __init__(self, **kwargs):
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log: if not params.show_log:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang) lang, det_lang = parse_lang(params.lang)
# init model dir # init model dir
det_model_config = get_model_config(params.version, 'det', det_lang) det_model_config = get_model_config('OCR', params.ocr_version, 'det',
det_lang)
params.det_model_dir, det_url = confirm_model_dir_url( params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir, params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang), os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
det_model_config['url']) det_model_config['url'])
rec_model_config = get_model_config(params.version, 'rec', lang) rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
lang)
params.rec_model_dir, rec_url = confirm_model_dir_url( params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir, params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang), os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
rec_model_config['url']) rec_model_config['url'])
table_model_config = get_model_config(params.version, 'table', 'en') table_model_config = get_model_config(
'STRUCTURE', params.structure_version, 'table', 'en')
params.table_model_dir, table_url = confirm_model_dir_url( params.table_model_dir, table_url = confirm_model_dir_url(
params.table_model_dir, params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'), os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
......
...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone ...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
from .base_model import BaseModel from .base_model import BaseModel
from ppocr.utils.save_load import init_model, load_pretrained_params from ppocr.utils.save_load import load_pretrained_params
__all__ = ['DistillationModel'] __all__ = ['DistillationModel']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment