Commit df001f3c authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 9cce1213 bdca6cd7
@@ -242,3 +242,7 @@ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Archi
 - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
 - microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
 - XFUND dataset, https://github.com/doc-analysis/XFUND
+
+## License
+
+The content of this project itself is licensed under the [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
@@ -12,4 +12,3 @@ cython
 lxml
 premailer
 openpyxl
-fasttext==0.9.1
@@ -56,7 +56,7 @@ PostProcess:
   thresh: 0
   box_thresh: 0.85
   min_area: 16
-  box_type: box # 'box' or 'poly'
+  box_type: quad # 'quad' or 'poly'
   scale: 1

 Metric:
...
@@ -55,7 +55,7 @@ PostProcess:
   thresh: 0
   box_thresh: 0.85
   min_area: 16
-  box_type: box # 'box' or 'poly'
+  box_type: quad # 'quad' or 'poly'
   scale: 1

 Metric:
...
@@ -60,6 +60,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
     ln -s ./icdar2015_lite ./icdar2015
     cd ../
     cd ./inference && tar xf rec_inference.tar && cd ../
+    if [ ${model_name} == "ch_PPOCRv2_det" ] || [ ${model_name} == "ch_PPOCRv2_det_PACT" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
+        cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
+    fi
+    if [ ${model_name} == "det_r18_db_v2_0" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
+    fi
     if [ ${model_name} == "en_server_pgnetA" ]; then
         wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
         wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
...
@@ -122,7 +122,7 @@ def preprocess(is_train=False):
         log_file = '{}/train.log'.format(save_model_dir)
     else:
         log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+    logger = get_logger(log_file=log_file)

     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['use_gpu']
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import json
import os
def poly_to_string(poly):
if len(poly.shape) > 1:
poly = np.array(poly).flatten()
string = "\t".join(str(i) for i in poly)
return string
def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
if not os.path.exists(label_dir):
raise ValueError(f"The file {label_dir} does not exist!")
    assert label_dir != save_dir, "the save_dir must be different from the input label_dir"
    with open(label_dir, 'r') as label_file:
        data = label_file.readlines()
gt_dict = {}
for line in data:
try:
tmp = line.split('\t')
            assert len(tmp) == 2, "each line must be an image path and a JSON annotation separated by a tab"
except:
tmp = line.strip().split(' ')
gt_lists = []
if tmp[0].split('/')[0] is not None:
img_path = tmp[0]
anno = json.loads(tmp[1])
gt_collect = []
for dic in anno:
#txt = dic['transcription'].replace(' ', '') # ignore blank
txt = dic['transcription']
if 'score' in dic and float(dic['score']) < 0.5:
continue
                # normalize full-width spaces (U+3000) to half-width spaces
                if u'\u3000' in txt:
                    txt = txt.replace(u'\u3000', u' ')
#while ' ' in txt:
# txt = txt.replace(' ', '')
poly = np.array(dic['points']).flatten()
if txt == "###":
                    txt_tag = 1  # tag 1 marks ignored text ("###")
else:
txt_tag = 0
if mode == "gt":
gt_label = poly_to_string(poly) + "\t" + str(
txt_tag) + "\t" + txt + "\n"
else:
gt_label = poly_to_string(poly) + "\t" + txt + "\n"
gt_lists.append(gt_label)
gt_dict[img_path] = gt_lists
else:
continue
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for img_name in gt_dict.keys():
save_name = img_name.split("/")[-1]
save_file = os.path.join(save_dir, save_name + ".txt")
with open(save_file, "w") as f:
f.writelines(gt_dict[img_name])
print("The convert label saved in {}".format(save_dir))
if __name__ == "__main__":
ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt"
convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/")
ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt"
convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
def str2bool(v):
return v.lower() in ("true", "t", "1")
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument("--image_dir", type=str, default="")
parser.add_argument("--save_html_path", type=str, default="./default.html")
parser.add_argument("--width", type=int, default=640)
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def draw_debug_img(args):
html_path = args.save_html_path
err_cnt = 0
with open(html_path, 'w') as html:
        html.write('<html>\n<head>\n')
        html.write(
            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />\n"
        )
        html.write('</head>\n<body>\n')
        html.write('<table border="1">\n')
image_list = []
path = args.image_dir
for i, filename in enumerate(sorted(os.listdir(path))):
            if filename.endswith("txt"):
                continue
            # The image path
            base = "{}/{}".format(path, filename)
html.write("<tr>\n")
            html.write(f'<td>{filename}</td>\n')
            html.write(f'<td>GT\n<img src="{base}" width={args.width}></td>\n')
html.write("</tr>\n")
html.write('<style>\n')
html.write('span {\n')
html.write(' color: red;\n')
html.write('}\n')
html.write('</style>\n')
html.write('</table>\n')
        html.write('</body>\n</html>\n')
    print(f"The html file is saved in {html_path}")
return
if __name__ == "__main__":
args = parse_args()
draw_debug_img(args)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import sys
import shapely
from shapely.geometry import Polygon
import numpy as np
from collections import defaultdict
import operator
import editdistance
def strQ2B(ustring):
    """Convert full-width characters in a string to their half-width forms."""
    rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 12288:
inside_code = 32
elif (inside_code >= 65281 and inside_code <= 65374):
inside_code -= 65248
rstring += chr(inside_code)
return rstring
def polygon_from_str(polygon_points):
"""
Create a shapely polygon object from gt or dt line.
"""
polygon_points = np.array(polygon_points).reshape(4, 2)
polygon = Polygon(polygon_points).convex_hull
return polygon
def polygon_iou(poly1, poly2):
"""
Intersection over union between two shapely polygons.
"""
if not poly1.intersects(
poly2): # this test is fast and can accelerate calculation
iou = 0
else:
try:
inter_area = poly1.intersection(poly2).area
union_area = poly1.area + poly2.area - inter_area
iou = float(inter_area) / union_area
except shapely.geos.TopologicalError:
# except Exception as e:
# print(e)
            print('shapely.geos.TopologicalError occurred, iou set to 0')
iou = 0
return iou
def ed(str1, str2):
return editdistance.eval(str1, str2)
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
print('start testing...')
iou_thresh = 0.5
val_names = os.listdir(gt_dir)
num_gt_chars = 0
gt_count = 0
dt_count = 0
hit = 0
ed_sum = 0
for i, val_name in enumerate(val_names):
with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
gt_lines = [o.strip() for o in f.readlines()]
gts = []
ignore_masks = []
for line in gt_lines:
parts = line.strip().split('\t')
# ignore illegal data
if len(parts) < 9:
continue
assert (len(parts) < 11)
if len(parts) == 9:
gts.append(parts[:8] + [''])
else:
gts.append(parts[:8] + [parts[-1]])
ignore_masks.append(parts[8])
val_path = os.path.join(res_dir, val_name)
if not os.path.exists(val_path):
dt_lines = []
else:
with open(val_path, encoding='utf-8') as f:
dt_lines = [o.strip() for o in f.readlines()]
dts = []
for line in dt_lines:
# print(line)
parts = line.strip().split("\t")
assert (len(parts) < 10), "line error: {}".format(line)
if len(parts) == 8:
dts.append(parts + [''])
else:
dts.append(parts)
dt_match = [False] * len(dts)
gt_match = [False] * len(gts)
all_ious = defaultdict(tuple)
for index_gt, gt in enumerate(gts):
gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
gt_poly = polygon_from_str(gt_coors)
for index_dt, dt in enumerate(dts):
dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
dt_poly = polygon_from_str(dt_coors)
iou = polygon_iou(dt_poly, gt_poly)
if iou >= iou_thresh:
all_ious[(index_gt, index_dt)] = iou
sorted_ious = sorted(
all_ious.items(), key=operator.itemgetter(1), reverse=True)
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
# matched gt and dt
for gt_dt_pair in sorted_gt_dt_pairs:
index_gt, index_dt = gt_dt_pair
            if not gt_match[index_gt] and not dt_match[index_dt]:
gt_match[index_gt] = True
dt_match[index_dt] = True
if ignore_blank:
gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
else:
gt_str = strQ2B(gts[index_gt][8])
dt_str = strQ2B(dts[index_dt][8])
if ignore_masks[index_gt] == '0':
ed_sum += ed(gt_str, dt_str)
num_gt_chars += len(gt_str)
if gt_str == dt_str:
hit += 1
gt_count += 1
dt_count += 1
# unmatched dt
for tindex, dt_match_flag in enumerate(dt_match):
            if not dt_match_flag:
dt_str = dts[tindex][8]
gt_str = ''
ed_sum += ed(dt_str, gt_str)
dt_count += 1
# unmatched gt
for tindex, gt_match_flag in enumerate(gt_match):
            if not gt_match_flag and ignore_masks[tindex] == '0':
dt_str = ''
gt_str = gts[tindex][8]
ed_sum += ed(gt_str, dt_str)
num_gt_chars += len(gt_str)
gt_count += 1
eps = 1e-9
print('hit, dt_count, gt_count', hit, dt_count, gt_count)
precision = hit / (dt_count + eps)
recall = hit / (gt_count + eps)
fmeasure = 2.0 * precision * recall / (precision + recall + eps)
avg_edit_dist_img = ed_sum / len(val_names)
avg_edit_dist_field = ed_sum / (gt_count + eps)
character_acc = 1 - ed_sum / (num_gt_chars + eps)
print('character_acc: %.2f' % (character_acc * 100) + "%")
print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
print('precision: %.2f' % (precision * 100) + "%")
print('recall: %.2f' % (recall * 100) + "%")
print('fmeasure: %.2f' % (fmeasure * 100) + "%")
if __name__ == '__main__':
    if len(sys.argv) != 3:
        print("usage: python3 eval_end2end.py gt_dir res_dir")
        exit(-1)
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
# Overview

The `tools/end2end` directory contains the metric evaluation code and visualization tools for the cascaded text detection + text recognition pipeline. This section describes how to evaluate the end-to-end metrics of the text detection + text recognition pipeline.

## End-to-end evaluation steps

**Step 1:**

Run `tools/infer/predict_system.py` to generate and save the prediction results:
```
python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
```
The visualized detection and recognition results are saved under `./ch_PP-OCRv2_results/` by default, and the prediction results are saved to `./ch_PP-OCRv2_results/system_results.txt` by default, in the following format:
```
all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
```
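Each line of `system_results.txt` pairs an image path with a JSON list of predicted regions, separated by a tab; every region carries a `transcription`, a four-point `points` polygon, and a recognition `score`. A minimal parsing sketch (for illustration only, not part of the toolchain):
```
import json

def parse_system_result_line(line):
    # the image path and the JSON annotation are separated by a tab
    img_path, anno_str = line.strip().split("\t", 1)
    regions = json.loads(anno_str)
    for region in regions:
        print(region["transcription"], region["points"], region["score"])
    return img_path, regions
```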
**Step 2:**

Convert the data saved in Step 1 into the format required for end-to-end evaluation:

Edit `tools/convert_ppocr_label.py`: in the `convert_label` calls, set the input label path, the mode, and the output label path, so that both the ground-truth labels and the prediction results are converted, e.g.:
```
ppocr_label_gt = "gt_label.txt"
convert_label(ppocr_label_gt, "gt", "./save_gt_label/")
ppocr_label_gt = "./ch_PP-OCRv2_results/system_results.txt"
convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
```
Then run `convert_ppocr_label.py`:
```
python3 tools/convert_ppocr_label.py
```
This produces the following directories:
```
├── ./save_gt_label/
├── ./save_PPOCRV2_infer/
```
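For reference, each line of a converted prediction file contains the eight flattened polygon coordinates followed by the recognized text, all tab-separated (ground-truth files carry an additional ignore tag before the text). For the first region of the example above, the converted line would look roughly like:
```
8.0	48.0	157.0	44.0	159.0	115.0	10.0	119.0	超赞
```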
**Step 3:**

Run `tools/eval_end2end.py` to compute the end-to-end metrics:
```
python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir"
```
For example:
```
python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
```
This produces output like the following, where `fmeasure` is the primary metric of interest:
```
hit, dt_count, gt_count 1557 2693 3283
character_acc: 61.77%
avg_edit_dist_field: 3.08
avg_edit_dist_img: 51.82
precision: 57.82%
recall: 47.43%
fmeasure: 52.11%
```
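These metrics follow directly from the counts on the first line: detections and ground truths are matched greedily by IoU (threshold 0.5), then `precision = hit / dt_count`, `recall = hit / gt_count`, and `fmeasure` is their harmonic mean, while `character_acc` is one minus the summed edit distance divided by the number of ground-truth characters. A quick sanity check of the numbers above:
```
hit, dt_count, gt_count = 1557, 2693, 3283
precision = hit / dt_count                                # ~0.5782
recall = hit / gt_count                                   # ~0.4743
fmeasure = 2 * precision * recall / (precision + recall)  # ~0.5211
print(f"precision={precision:.2%} recall={recall:.2%} fmeasure={fmeasure:.2%}")
```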
@@ -150,27 +150,13 @@ class TextDetector(object):
             logger=logger)

     def order_points_clockwise(self, pts):
-        """
-        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
-        # sort the points based on their x-coordinates
-        """
-        xSorted = pts[np.argsort(pts[:, 0]), :]
-
-        # grab the left-most and right-most points from the sorted
-        # x-roodinate points
-        leftMost = xSorted[:2, :]
-        rightMost = xSorted[2:, :]
-
-        # now, sort the left-most coordinates according to their
-        # y-coordinates so we can grab the top-left and bottom-left
-        # points, respectively
-        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
-        (tl, bl) = leftMost
-
-        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
-        (tr, br) = rightMost
-
-        rect = np.array([tl, tr, br, bl], dtype="float32")
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        diff = np.diff(pts, axis=1)
+        rect[1] = pts[np.argmin(diff)]
+        rect[3] = pts[np.argmax(diff)]
         return rect

     def clip_det_res(self, points, img_height, img_width):
...
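The rewritten `order_points_clockwise` above sorts the four corners by coordinate sums and differences: the top-left corner has the smallest x + y, the bottom-right the largest x + y, and the top-right/bottom-left have the smallest/largest y - x (this assumes a roughly axis-aligned quadrilateral). A self-contained sketch of the same idea:
```
import numpy as np

def order_points_clockwise(pts):
    # pts: (4, 2) array of corner points in arbitrary order
    rect = np.zeros((4, 2), dtype="float32")
    s = pts.sum(axis=1)              # x + y per point
    rect[0] = pts[np.argmin(s)]      # top-left: smallest x + y
    rect[2] = pts[np.argmax(s)]      # bottom-right: largest x + y
    diff = np.diff(pts, axis=1)      # y - x per point
    rect[1] = pts[np.argmin(diff)]   # top-right: smallest y - x
    rect[3] = pts[np.argmax(diff)]   # bottom-left: largest y - x
    return rect

# example: an axis-aligned square given in arbitrary corner order
corners = np.array([[10, 10], [0, 10], [0, 0], [10, 0]])
# returns the corners ordered [top-left, top-right, bottom-right, bottom-left]
print(order_points_clockwise(corners))
```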
@@ -622,7 +622,6 @@ def get_rotate_crop_image(img, points):
 def check_gpu(use_gpu):
     if use_gpu and not paddle.is_compiled_with_cuda():
         use_gpu = False
     return use_gpu
...
@@ -151,7 +151,7 @@ def preprocess():
     ser_config = load_config(FLAGS.config_ser)
     ser_config = merge_config(ser_config, FLAGS.opt_ser)

-    logger = get_logger(name='root')
+    logger = get_logger()

     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
...
@@ -525,7 +525,7 @@ def preprocess(is_train=False):
         log_file = '{}/train.log'.format(save_model_dir)
     else:
         log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+    logger = get_logger(log_file=log_file)

     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
...
@@ -25,7 +25,9 @@ import numpy as np
 import time
 from PIL import Image
 from ppocr.utils.utility import get_image_file_list
-from tools.infer.utility import draw_ocr, draw_boxes
+from tools.infer.utility import draw_ocr, draw_boxes, str2bool
+from ppstructure.utility import draw_structure_result
+from ppstructure.predict_system import to_excel
 import requests
 import json
@@ -69,8 +71,33 @@ def draw_server_result(image_file, res):
     return draw_img

-def main(url, image_path):
-    image_file_list = get_image_file_list(image_path)
+def save_structure_res(res, save_folder, image_file):
+    img = cv2.imread(image_file)
+    excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
+    os.makedirs(excel_save_folder, exist_ok=True)
+    # save res
+    with open(
+            os.path.join(excel_save_folder, 'res.txt'), 'w',
+            encoding='utf8') as f:
+        for region in res:
+            if region['type'] == 'Table':
+                excel_path = os.path.join(excel_save_folder,
+                                          '{}.xlsx'.format(region['bbox']))
+                to_excel(region['res'], excel_path)
+            elif region['type'] == 'Figure':
+                x1, y1, x2, y2 = region['bbox']
+                print(region['bbox'])
+                roi_img = img[y1:y2, x1:x2, :]
+                img_path = os.path.join(excel_save_folder,
+                                        '{}.jpg'.format(region['bbox']))
+                cv2.imwrite(img_path, roi_img)
+            else:
+                for text_result in region['res']:
+                    f.write('{}\n'.format(json.dumps(text_result)))
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
     is_visualize = False
     headers = {"Content-type": "application/json"}
     cnt = 0
@@ -80,38 +107,51 @@ def main(url, image_path):
         if img is None:
             logger.info("error in loading image:{}".format(image_file))
             continue
-        # 发送HTTP请求
+        img_name = os.path.basename(image_file)
+        # send http request
         starttime = time.time()
         data = {'images': [cv2_to_base64(img)]}
-        r = requests.post(url=url, headers=headers, data=json.dumps(data))
+        r = requests.post(
+            url=args.server_url, headers=headers, data=json.dumps(data))
         elapse = time.time() - starttime
         total_time += elapse
         logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
         res = r.json()["results"][0]
         logger.info(res)

-        if is_visualize:
-            draw_img = draw_server_result(image_file, res)
+        if args.visualize:
+            draw_img = None
+            if 'structure_table' in args.server_url:
+                to_excel(res['html'], './{}.xlsx'.format(img_name))
+            elif 'structure_system' in args.server_url:
+                save_structure_res(res['regions'], args.output, image_file)
+            else:
+                draw_img = draw_server_result(image_file, res)
             if draw_img is not None:
-                draw_img_save = "./server_results/"
-                if not os.path.exists(draw_img_save):
-                    os.makedirs(draw_img_save)
+                if not os.path.exists(args.output):
+                    os.makedirs(args.output)
                 cv2.imwrite(
-                    os.path.join(draw_img_save, os.path.basename(image_file)),
+                    os.path.join(args.output, os.path.basename(image_file)),
                     draw_img[:, :, ::-1])
                 logger.info("The visualized image saved in {}".format(
-                    os.path.join(draw_img_save, os.path.basename(image_file))))
+                    os.path.join(args.output, os.path.basename(image_file))))
         cnt += 1
         if cnt % 100 == 0:
             logger.info("{} processed".format(cnt))
     logger.info("avg time cost: {}".format(float(total_time) / cnt))
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description="args for hub serving")
+    parser.add_argument("--server_url", type=str, required=True)
+    parser.add_argument("--image_dir", type=str, required=True)
+    parser.add_argument("--visualize", type=str2bool, default=False)
+    parser.add_argument("--output", type=str, default='./hubserving_result')
+    args = parser.parse_args()
+    return args
+

 if __name__ == '__main__':
-    if len(sys.argv) != 3:
-        logger.info("Usage: %s server_url image_path" % sys.argv[0])
-    else:
-        server_url = sys.argv[1]
-        image_path = sys.argv[2]
-        main(server_url, image_path)
+    args = parse_args()
+    main(args)
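With the new argument parser, the test script is invoked with named flags instead of positional arguments. A hypothetical invocation (the script path and server URL below are placeholders; only the flags come from the diff):
```
python3 tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir=./doc/imgs/ --visualize=false --output=./hubserving_result
```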