Commit fe137242 authored by WenmuZhou's avatar WenmuZhou
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into tree_doc

parents 53d4eab6 b1623d69
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_rec(self):
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = []
for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im)
feed_list = []
max_wh_ratio = 0
for i, boximg in enumerate(img_list):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": res_lst}
return res
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.init_rec()
if sys.argv[1] == 'gpu':
ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
elif sys.argv[1] == 'cpu':
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
...@@ -37,8 +37,6 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset ...@@ -37,8 +37,6 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。 若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
<a name="自定义数据集"></a> <a name="自定义数据集"></a>
* 使用自己数据集 * 使用自己数据集
...@@ -65,7 +63,7 @@ wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_t ...@@ -65,7 +63,7 @@ wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_t
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
``` ```
PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支持的数据格式。 数据转换工具在 `train_data/gen_label.py`, 这里以训练集为例: PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支持的数据格式。 数据转换工具在 `ppocr/utils/gen_label.py`, 这里以训练集为例:
``` ```
# 将官网下载的标签文件转换为 rec_gt_label.txt # 将官网下载的标签文件转换为 rec_gt_label.txt
...@@ -116,9 +114,9 @@ n ...@@ -116,9 +114,9 @@ n
word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1] word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1]
`ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典 `ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典
`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典 `ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典
`ppocr/utils/dict/french_dict.txt` 是一个包含118个字符的法文字典 `ppocr/utils/dict/french_dict.txt` 是一个包含118个字符的法文字典
...@@ -128,6 +126,8 @@ word_dict.txt 每行有一个单字,将字符与数字索引映射在一起, ...@@ -128,6 +126,8 @@ word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,
`ppocr/utils/dict/german_dict.txt` 是一个包含131个字符的德文字典 `ppocr/utils/dict/german_dict.txt` 是一个包含131个字符的德文字典
`ppocr/utils/dict/en_dict.txt` 是一个包含63个字符的英文字典
您可以按需使用。 您可以按需使用。
...@@ -155,10 +155,10 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 CRNN ...@@ -155,10 +155,10 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 CRNN
``` ```
cd PaddleOCR/ cd PaddleOCR/
# 下载MobileNetV3的预训练模型 # 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# 解压模型参数 # 解压模型参数
cd pretrain_models cd pretrain_models
tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc_v2.0_train.tar
``` ```
开始训练: 开始训练:
...@@ -204,9 +204,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t ...@@ -204,9 +204,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention | | rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention |
| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc | | rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
......
...@@ -120,6 +120,9 @@ In `word_dict.txt`, there is a single word in each line, which maps characters a ...@@ -120,6 +120,9 @@ In `word_dict.txt`, there is a single word in each line, which maps characters a
`ppocr/utils/dict/german_dict.txt` is a German dictionary with 131 characters `ppocr/utils/dict/german_dict.txt` is a German dictionary with 131 characters
`ppocr/utils/dict/en_dict.txt` is a English dictionary with 63 characters
You can use it on demand. You can use it on demand.
The current multi-language model is still in the demo stage and will continue to optimize the model and add languages. **You are very welcome to provide us with dictionaries and fonts in other languages**, The current multi-language model is still in the demo stage and will continue to optimize the model and add languages. **You are very welcome to provide us with dictionaries and fonts in other languages**,
...@@ -149,10 +152,10 @@ First download the pretrain model, you can download the trained model to finetun ...@@ -149,10 +152,10 @@ First download the pretrain model, you can download the trained model to finetun
``` ```
cd PaddleOCR/ cd PaddleOCR/
# Download the pre-trained model of MobileNetV3 # Download the pre-trained model of MobileNetV3
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# Decompress model parameters # Decompress model parameters
cd pretrain_models cd pretrain_models
tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc_v2.0_train.tar
``` ```
Start training: Start training:
...@@ -194,7 +197,6 @@ If the evaluation set is large, the test will be time-consuming. It is recommend ...@@ -194,7 +197,6 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention | | rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention |
| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc | | rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
For training Chinese data, it is recommended to use For training Chinese data, it is recommended to use
......
doc/joinus.PNG

15.7 KB | W: | H:

doc/joinus.PNG

408 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -67,6 +67,7 @@ def build_dataloader(config, mode, device, logger): ...@@ -67,6 +67,7 @@ def build_dataloader(config, mode, device, logger):
drop_last = loader_config['drop_last'] drop_last = loader_config['drop_last']
num_workers = loader_config['num_workers'] num_workers = loader_config['num_workers']
use_shared_memory = False
if mode == "Train": if mode == "Train":
#Distribute data to multiple cards #Distribute data to multiple cards
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
...@@ -74,6 +75,7 @@ def build_dataloader(config, mode, device, logger): ...@@ -74,6 +75,7 @@ def build_dataloader(config, mode, device, logger):
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
drop_last=drop_last) drop_last=drop_last)
use_shared_memory = True
else: else:
#Distribute data to single card #Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
...@@ -87,6 +89,7 @@ def build_dataloader(config, mode, device, logger): ...@@ -87,6 +89,7 @@ def build_dataloader(config, mode, device, logger):
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
places=device, places=device,
num_workers=num_workers, num_workers=num_workers,
return_list=True) return_list=True,
use_shared_memory=use_shared_memory)
return data_loader return data_loader
...@@ -35,12 +35,13 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort ...@@ -35,12 +35,13 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort
class RecAug(object): class RecAug(object):
def __init__(self, use_tia=True, **kwargsz): def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
self.use_tia = use_tia self.use_tia = use_tia
self.aug_prob = aug_prob
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
img = warp(img, 10, self.use_tia) img = warp(img, 10, self.use_tia, self.aug_prob)
data['image'] = img data['image'] = img
return data return data
...@@ -329,7 +330,7 @@ def get_warpAffine(config): ...@@ -329,7 +330,7 @@ def get_warpAffine(config):
return rz return rz
def warp(img, ang, use_tia=True): def warp(img, ang, use_tia=True, prob=0.4):
""" """
warp warp
""" """
...@@ -338,8 +339,6 @@ def warp(img, ang, use_tia=True): ...@@ -338,8 +339,6 @@ def warp(img, ang, use_tia=True):
config.make(w, h, ang) config.make(w, h, ang)
new_img = img new_img = img
prob = 0.4
if config.distort: if config.distort:
img_height, img_width = img.shape[0:2] img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20: if random.random() <= prob and img_height >= 20 and img_width >= 20:
......
...@@ -40,7 +40,7 @@ class DBPostProcess(object): ...@@ -40,7 +40,7 @@ class DBPostProcess(object):
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]] self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
......
...@@ -33,4 +33,4 @@ v ...@@ -33,4 +33,4 @@ v
w w
x x
y y
z z
\ No newline at end of file
...@@ -63,6 +63,7 @@ class TextDetector(object): ...@@ -63,6 +63,7 @@ class TextDetector(object):
postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -111,7 +112,7 @@ class TextDetector(object): ...@@ -111,7 +112,7 @@ class TextDetector(object):
box = self.clip_det_res(box, img_height, img_width) box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1])) rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3])) rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10: if rect_width <= 3 or rect_height <= 3:
continue continue
dt_boxes_new.append(box) dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new) dt_boxes = np.array(dt_boxes_new)
......
...@@ -17,8 +17,9 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -17,8 +17,9 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.utils.utility import initial_logger from ppocr.utils.logging import get_logger
logger = initial_logger() logger = get_logger()
import cv2 import cv2
import numpy as np import numpy as np
import time import time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment