Unverified Commit 6ab12348 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2625 from opendatalab/release-2.0.0

Release 2.0.0
parents 9487d33d 4fbec469
lang:
ch_lite:
det: ch_PP-OCRv3_det_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_infer.pth
dict: ppocrv5_dict.txt
ch_lite_v4:
det: ch_PP-OCRv3_det_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv4_rec_infer.pth
dict: ppocr_keys_v1.txt
ch_server:
det: ch_PP-OCRv3_det_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_server_infer.pth
dict: ppocrv5_dict.txt
ch_server_v4:
det: ch_PP-OCRv3_det_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv4_rec_server_infer.pth
dict: ppocr_keys_v1.txt
ch:
det: ch_PP-OCRv3_det_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv4_rec_server_doc_infer.pth
dict: ppocrv4_doc_dict.txt
en:
......@@ -28,12 +28,12 @@ lang:
rec: korean_PP-OCRv3_rec_infer.pth
dict: korean_dict.txt
japan:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: japan_PP-OCRv3_rec_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_server_infer.pth
dict: japan_dict.txt
chinese_cht:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: chinese_cht_PP-OCRv3_rec_infer.pth
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_server_infer.pth
dict: chinese_cht_dict.txt
ta:
det: Multilingual_PP-OCRv3_det_infer.pth
......
# Copyright (c) Opendatalab. All rights reserved.
......@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self.net.eval()
self.net.to(self.device)
def _batch_process_same_size(self, img_list):
    """Run text detection on a batch of images that share one size.

    Args:
        img_list: list of images (assumed same H x W x C so they can be
            stacked into one tensor -- TODO confirm caller guarantees this).

    Returns:
        tuple: (batch_results, total_elapse) where batch_results is a list
        of (dt_boxes, elapse) pairs, one per input image in input order,
        and total_elapse is the wall-clock time for the whole batch.
    """
    # Guard the empty batch: avoids a ZeroDivisionError in the per-image
    # elapse computation below.
    if not img_list:
        return [], 0

    starttime = time.time()

    def _sequential_fallback():
        # Process images one at a time through the single-image path.
        # Used when batching is impossible (preprocessing failed or the
        # images do not actually share a shape).
        results = [self.__call__(img) for img in img_list]
        return results, time.time() - starttime

    # Preprocess every image; keep the originals for box clipping later.
    batch_data = []
    batch_shapes = []
    ori_imgs = []
    for img in img_list:
        ori_imgs.append(img.copy())
        data = transform({'image': img}, self.preprocess_op)
        if data is None:
            # Preprocessing failed for this image. Fall back to sequential
            # processing so the remaining images still get real results
            # (consistent with the stack-failure handling below).
            return _sequential_fallback()
        img_processed, shape_list = data
        batch_data.append(img_processed)
        batch_shapes.append(shape_list)

    # Stack into one batch tensor; if shapes disagree np.stack raises and
    # we fall back to per-image inference.
    try:
        batch_tensor = np.stack(batch_data, axis=0)
        shape_array = np.stack(batch_shapes, axis=0)
    except Exception:
        return _sequential_fallback()

    # Batched forward pass.
    with torch.no_grad():
        inp = torch.from_numpy(batch_tensor).to(self.device)
        outputs = self.net(inp)

    # Collect the network outputs expected by each detection algorithm.
    preds = {}
    if self.det_algorithm == "EAST":
        preds['f_geo'] = outputs['f_geo'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
    elif self.det_algorithm == 'SAST':
        preds['f_border'] = outputs['f_border'].cpu().numpy()
        preds['f_score'] = outputs['f_score'].cpu().numpy()
        preds['f_tco'] = outputs['f_tco'].cpu().numpy()
        preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
    elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
        preds['maps'] = outputs['maps'].cpu().numpy()
    elif self.det_algorithm == 'FCE':
        for i, (k, output) in enumerate(outputs.items()):
            preds['level_{}'.format(i)] = output.cpu().numpy()
    else:
        raise NotImplementedError

    # Post-process each image's slice of the batched predictions.
    batch_results = []
    total_elapse = time.time() - starttime
    per_image_elapse = total_elapse / len(img_list)
    for i in range(len(img_list)):
        # Slice with i:i+1 to keep the batch dimension the post-processor
        # expects.
        single_preds = {
            key: value[i:i + 1] if isinstance(value, np.ndarray) else value
            for key, value in preds.items()
        }
        post_result = self.postprocess_op(single_preds, shape_array[i:i + 1])
        dt_boxes = post_result[0]['points']

        # Polygon outputs are only clipped to the image; quad outputs are
        # also filtered/ordered.
        if (self.det_algorithm == "SAST" and
                self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
                                           self.postprocess_op.box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
        batch_results.append((dt_boxes, per_image_elapse))

    return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
    """Detect text boxes for many images by processing them in batches.

    Args:
        img_list: list of input images.
        max_batch_size: upper bound on how many images go through the
            network in one forward pass.

    Returns:
        list of (dt_boxes, elapse) tuples, one per input image, in the
        same order as ``img_list``.
    """
    if not img_list:
        return []

    results = []
    # Walk the input in contiguous chunks of at most max_batch_size images.
    # NOTE(review): assumes each chunk contains same-sized images -- the
    # batch helper falls back to per-image processing otherwise.
    for start in range(0, len(img_list), max_batch_size):
        chunk = img_list[start:start + max_batch_size]
        chunk_results, _ = self._batch_process_same_size(chunk)
        results.extend(chunk_results)
    return results
def order_points_clockwise(self, pts):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment