"vscode:/vscode.git/clone" did not exist on "1f6d3149edc3f2d1c5e5a3829d37153f14d1dbd4"
Commit 78d51971 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'upstream/dygraph' into dy3

parents bd314018 c683a181
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,27 +11,13 @@ ...@@ -11,27 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# -*- coding: utf-8 -*- from engine.synthesisers import DatasetSynthesiser
import requests
import json
import cv2
import base64
import os, sys
import time
def cv2_to_base64(image): def synth_dataset():
#data = cv2.imencode('.jpg', image)[1] dataset_synthesiser = DatasetSynthesiser()
return base64.b64encode(image).decode( dataset_synthesiser.synth_dataset()
'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction" if __name__ == '__main__':
test_img_dir = "../../doc/imgs/" synth_dataset()
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import sys
import glob
from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def synth_image():
args = ArgsParser().parse_args()
image_synthesiser = ImageSynthesiser()
style_image_path = args.style_image
img = cv2.imread(style_image_path)
text_corpus = args.text_corpus
language = args.language
synth_result = image_synthesiser.synth_image(text_corpus, img, language)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
cv2.imwrite("fake_fusion.jpg", fake_fusion)
cv2.imwrite("fake_text.jpg", fake_text)
cv2.imwrite("fake_bg.jpg", fake_bg)
def batch_synth_images():
image_synthesiser = ImageSynthesiser()
corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
save_path = "./output_data/"
corpus_list = []
with open(corpus_file, "rb") as fin:
lines = fin.readlines()
for line in lines:
substr = line.decode("utf-8").strip("\n").split("\t")
corpus_list.append(substr)
style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
corpus_num = len(corpus_list)
style_img_num = len(style_img_list)
for cno in range(corpus_num):
for sno in range(style_img_num):
corpus, lang = corpus_list[cno]
style_img_path = style_img_list[sno]
img = cv2.imread(style_img_path)
synth_result = image_synthesiser.synth_image(corpus, img, lang)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
for tp in range(2):
if tp == 0:
prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
else:
prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
cv2.imwrite("%s_input_style.jpg" % prefix, img)
print(cno, corpus_num, sno, style_img_num)
if __name__ == '__main__':
# batch_synth_images()
synth_image()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
v(str): value to be replaced
"""
def str2num(v):
try:
return eval(v)
except Exception:
return v
assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
assert len(ks) > 0, ('lenght of keys should larger than 0')
if isinstance(dl, list):
k = str2num(ks[0])
if len(ks) == 1:
assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
dl[k] = str2num(v)
else:
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
#assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
assert ks[0] in dl, (
'({}) doesn\'t exist in {}, a new dict field is invalid'.
format(ks[0], dl))
override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
such as: [
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
if options is not None:
for opt in options:
assert isinstance(opt, str), (
"option({}) should be a str".format(opt))
assert "=" in opt, (
"option({}) should contain a ="
"to distinguish between key and value".format(opt))
pair = opt.split('=')
assert len(pair) == 2, ("there can be only a = in the option")
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-t", "--tag", default="0", help="tag for marking worker")
self.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
self.add_argument(
"--style_image", default="examples/style_images/1.jpg", help="tag for marking worker")
self.add_argument(
"--text_corpus", default="PaddleOCR", help="tag for marking worker")
self.add_argument(
"--language", default="en", help="tag for marking worker")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
return args
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: config
"""
ext = os.path.splitext(file_path)[1]
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
with open(file_path, 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
return config
def gen_config():
base_config = {
"Global": {
"algorithm": "SRNet",
"use_gpu": True,
"start_epoch": 1,
"stage1_epoch_num": 100,
"stage2_epoch_num": 100,
"log_smooth_window": 20,
"print_batch_step": 2,
"save_model_dir": "./output/SRNet",
"use_visualdl": False,
"save_epoch_step": 10,
"vgg_pretrain": "./pretrained/VGG19_pretrained",
"vgg_load_static_pretrain": True
},
"Architecture": {
"model_type": "data_aug",
"algorithm": "SRNet",
"net_g": {
"name": "srnet_net_g",
"encode_dim": 64,
"norm": "batch",
"use_dropout": False,
"init_type": "xavier",
"init_gain": 0.02,
"use_dilation": 1
},
# input_nc, ndf, netD,
# n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
"bg_discriminator": {
"name": "srnet_bg_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
},
"fusion_discriminator": {
"name": "srnet_fusion_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
}
},
"Loss": {
"lamb": 10,
"perceptual_lamb": 1,
"muvar_lamb": 50,
"style_lamb": 500
},
"Optimizer": {
"name": "Adam",
"learning_rate": {
"name": "lambda",
"lr": 0.0002,
"lr_decay_iters": 50
},
"beta1": 0.5,
"beta2": 0.999,
},
"Train": {
"batch_size_per_card": 8,
"num_workers_per_card": 4,
"dataset": {
"delimiter": "\t",
"data_dir": "/",
"label_file": "tmp/label.txt",
"transforms": [{
"DecodeImage": {
"to_rgb": True,
"to_np": False,
"channel_first": False
}
}, {
"NormalizeImage": {
"scale": 1. / 255.,
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"order": None
}
}, {
"ToCHWImage": None
}]
}
}
}
with open("config.yml", "w") as f:
yaml.dump(base_config, f)
if __name__ == '__main__':
gen_config()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
__all__ = ['load_dygraph_pretrain']
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
if not os.path.exists(path + '.pdparams'):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(path))
return
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def compute_mean_covariance(img):
batch_size = img.shape[0]
channel_num = img.shape[1]
height = img.shape[2]
width = img.shape[3]
num_pixels = height * width
# batch_size * channel_num * 1 * 1
mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
# batch_size * channel_num * num_pixels
img_hat = img - mu.expand_as(img)
img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
# batch_size * num_pixels * channel_num
img_hat_transpose = img_hat.transpose([0, 2, 1])
# batch_size * channel_num * channel_num
covariance = paddle.bmm(img_hat, img_hat_transpose)
covariance = covariance / num_pixels
return mu, covariance
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
eps = 1e-5
intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
y_pred_cls * training_mask) + eps
loss = 1. - (2 * intersection / union)
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle
def get_check_global_params(mode):
check_params = [
'use_gpu', 'max_text_length', 'image_shape', 'image_shape',
'character_type', 'loss_type'
]
if mode == "train_eval":
check_params = check_params + [
'train_batch_size_per_card', 'test_batch_size_per_card'
]
elif mode == "test":
check_params = check_params + ['test_batch_size_per_card']
return check_params
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
if use_gpu:
try:
if not paddle.is_compiled_with_cuda():
print(err)
sys.exit(1)
except:
print("Fail to check gpu state.")
sys.exit(1)
def _mkdir_if_not_exist(path, logger):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
...@@ -61,8 +61,8 @@ Train: ...@@ -61,8 +61,8 @@ Train:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./train_data/ data_dir: ./train_data/
label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt] label_file_list: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
data_ratio_list: [0.5, 0.5] ratio_list: [0.1, 0.45, 0.3, 0.15]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
...@@ -60,8 +60,8 @@ Metric: ...@@ -60,8 +60,8 @@ Metric:
Train: Train:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
label_file_list: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt] label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt]
ratio_list: [0.1, 0.45, 0.3, 0.15] data_ratio_list: [0.5, 0.5]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
...@@ -36,12 +36,13 @@ Architecture: ...@@ -36,12 +36,13 @@ Architecture:
algorithm: CRNN algorithm: CRNN
Transform: Transform:
Backbone: Backbone:
name: ResNet name: MobileNetV3
layers: 34 scale: 0.5
model_name: large
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: rnn
hidden_size: 256 hidden_size: 96
Head: Head:
name: CTCHead name: CTCHead
fc_decay: 0 fc_decay: 0
......
...@@ -12,7 +12,7 @@ def read_params(): ...@@ -12,7 +12,7 @@ def read_params():
cfg = Config() cfg = Config()
#params for text classifier #params for text classifier
cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v1.1_cls_infer/" cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v2.0_cls_infer/"
cfg.cls_image_shape = "3, 48, 192" cfg.cls_image_shape = "3, 48, 192"
cfg.label_list = ['0', '180'] cfg.label_list = ['0', '180']
cfg.cls_batch_num = 30 cfg.cls_batch_num = 30
......
...@@ -13,7 +13,7 @@ def read_params(): ...@@ -13,7 +13,7 @@ def read_params():
#params for text detector #params for text detector
cfg.det_algorithm = "DB" cfg.det_algorithm = "DB"
cfg.det_model_dir = "./inference/ch_ppocr_mobile_v1.1_det_infer/" cfg.det_model_dir = "./inference/ch_ppocr_mobile_v2.0_det_infer/"
cfg.det_limit_side_len = 960 cfg.det_limit_side_len = 960
cfg.det_limit_type = 'max' cfg.det_limit_type = 'max'
...@@ -27,16 +27,6 @@ def read_params(): ...@@ -27,16 +27,6 @@ def read_params():
# cfg.det_east_cover_thresh = 0.1 # cfg.det_east_cover_thresh = 0.1
# cfg.det_east_nms_thresh = 0.2 # cfg.det_east_nms_thresh = 0.2
# #params for text recognizer
# cfg.rec_algorithm = "CRNN"
# cfg.rec_model_dir = "./inference/ch_det_mv3_crnn/"
# cfg.rec_image_shape = "3, 32, 320"
# cfg.rec_char_type = 'ch'
# cfg.rec_batch_num = 30
# cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
# cfg.use_space_char = True
cfg.use_zero_copy_run = False cfg.use_zero_copy_run = False
cfg.use_pdserving = False cfg.use_pdserving = False
......
...@@ -13,7 +13,7 @@ def read_params(): ...@@ -13,7 +13,7 @@ def read_params():
#params for text detector #params for text detector
cfg.det_algorithm = "DB" cfg.det_algorithm = "DB"
cfg.det_model_dir = "./inference/ch_ppocr_mobile_v1.1_det_infer/" cfg.det_model_dir = "./inference/ch_ppocr_mobile_v2.0_det_infer/"
cfg.det_limit_side_len = 960 cfg.det_limit_side_len = 960
cfg.det_limit_type = 'max' cfg.det_limit_type = 'max'
...@@ -29,7 +29,7 @@ def read_params(): ...@@ -29,7 +29,7 @@ def read_params():
#params for text recognizer #params for text recognizer
cfg.rec_algorithm = "CRNN" cfg.rec_algorithm = "CRNN"
cfg.rec_model_dir = "./inference/ch_ppocr_mobile_v1.1_rec_infer/" cfg.rec_model_dir = "./inference/ch_ppocr_mobile_v2.0_rec_infer/"
cfg.rec_image_shape = "3, 32, 320" cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch' cfg.rec_char_type = 'ch'
...@@ -41,7 +41,7 @@ def read_params(): ...@@ -41,7 +41,7 @@ def read_params():
#params for text classifier #params for text classifier
cfg.use_angle_cls = True cfg.use_angle_cls = True
cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v1.1_cls_infer/" cfg.cls_model_dir = "./inference/ch_ppocr_mobile_v2.0_cls_infer/"
cfg.cls_image_shape = "3, 48, 192" cfg.cls_image_shape = "3, 48, 192"
cfg.label_list = ['0', '180'] cfg.label_list = ['0', '180']
cfg.cls_batch_num = 30 cfg.cls_batch_num = 30
...@@ -49,5 +49,6 @@ def read_params(): ...@@ -49,5 +49,6 @@ def read_params():
cfg.use_zero_copy_run = False cfg.use_zero_copy_run = False
cfg.use_pdserving = False cfg.use_pdserving = False
cfg.drop_score = 0.5
return cfg return cfg
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
PaddleOCR提供2种服务部署方式: PaddleOCR提供2种服务部署方式:
- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",按照本教程使用; - 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",按照本教程使用;
- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/readme.md) - (coming soon)基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/readme.md)
# 基于PaddleHub Serving的服务部署 # 基于PaddleHub Serving的服务部署
...@@ -33,11 +33,11 @@ pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple ...@@ -33,11 +33,11 @@ pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
``` ```
### 2. 下载推理模型 ### 2. 下载推理模型
安装服务模块前,需要准备推理模型并放到正确路径。默认使用的是v1.1版的超轻量模型,默认模型路径为: 安装服务模块前,需要准备推理模型并放到正确路径。默认使用的是v2.0版的超轻量模型,默认模型路径为:
``` ```
检测模型:./inference/ch_ppocr_mobile_v1.1_det_infer/ 检测模型:./inference/ch_ppocr_mobile_v2.0_det_infer/
识别模型:./inference/ch_ppocr_mobile_v1.1_rec_infer/ 识别模型:./inference/ch_ppocr_mobile_v2.0_rec_infer/
方向分类器:./inference/ch_ppocr_mobile_v1.1_cls_infer/ 方向分类器:./inference/ch_ppocr_mobile_v2.0_cls_infer/
``` ```
**模型路径可在`params.py`中查看和修改。** 更多模型可以从PaddleOCR提供的[模型库](../../doc/doc_ch/models_list.md)下载,也可以替换成自己训练转换好的模型。 **模型路径可在`params.py`中查看和修改。** 更多模型可以从PaddleOCR提供的[模型库](../../doc/doc_ch/models_list.md)下载,也可以替换成自己训练转换好的模型。
......
...@@ -2,7 +2,7 @@ English | [简体中文](readme.md) ...@@ -2,7 +2,7 @@ English | [简体中文](readme.md)
PaddleOCR provides 2 service deployment methods: PaddleOCR provides 2 service deployment methods:
- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please follow this tutorial. - Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please follow this tutorial.
- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/readme.md) for usage. - (coming soon)Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/readme.md) for usage.
# Service deployment based on PaddleHub Serving # Service deployment based on PaddleHub Serving
...@@ -34,11 +34,11 @@ pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple ...@@ -34,11 +34,11 @@ pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
``` ```
### 2. Download inference model ### 2. Download inference model
Before installing the service module, you need to prepare the inference model and put it in the correct path. By default, the ultra lightweight model of v1.1 is used, and the default model path is: Before installing the service module, you need to prepare the inference model and put it in the correct path. By default, the ultra lightweight model of v2.0 is used, and the default model path is:
``` ```
detection model: ./inference/ch_ppocr_mobile_v1.1_det_infer/ detection model: ./inference/ch_ppocr_mobile_v2.0_det_infer/
recognition model: ./inference/ch_ppocr_mobile_v1.1_rec_infer/ recognition model: ./inference/ch_ppocr_mobile_v2.0_rec_infer/
text direction classifier: ./inference/ch_ppocr_mobile_v1.1_cls_infer/ text direction classifier: ./inference/ch_ppocr_mobile_v2.0_cls_infer/
``` ```
**The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself. **The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_det(self):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"]
def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
return {"dt_boxes": dt_boxes.tolist()}
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_det_model")
ocr_service.init_det()
if sys.argv[1] == 'gpu':
ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.run_debugger_service(gpu=True)
elif sys.argv[1] == 'cpu':
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.run_debugger_service()
ocr_service.init_det()
ocr_service.run_web_service()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment