Unverified Commit 14fce808 authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #3376 from WenmuZhou/fx_pse

add pse
parents 82daba23 ac98415b
import numpy as np
import cv2
cimport numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
from libcpp.pair cimport *
from libcpp.queue cimport *
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
np.ndarray[np.int32_t, ndim=2] label,
int kernel_num,
int label_num,
float min_area=0):
cdef np.ndarray[np.int32_t, ndim=2] pred
pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
for label_idx in range(1, label_num):
if np.sum(label == label_idx) < min_area:
label[label == label_idx] = 0
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
cdef np.int16_t* dx = [-1, 1, 0, 0]
cdef np.int16_t* dy = [0, 0, -1, 1]
cdef np.int16_t tmpx, tmpy
points = np.array(np.where(label > 0)).transpose((1, 0))
for point_idx in range(points.shape[0]):
tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
pred[tmpx, tmpy] = label[tmpx, tmpy]
cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
cdef int cur_label
for kernel_idx in range(kernel_num - 1, -1, -1):
while not que.empty():
cur = que.front()
que.pop()
cur_label = pred[cur.first, cur.second]
is_edge = True
for j in range(4):
tmpx = cur.first + dx[j]
tmpy = cur.second + dy[j]
if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
continue
if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
continue
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
pred[tmpx, tmpy] = cur_label
is_edge = False
if is_edge:
nxt_que.push(cur)
que, nxt_que = nxt_que, que
return pred
def pse(kernels, min_area):
kernel_num = kernels.shape[0]
label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
\ No newline at end of file
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
setup(ext_modules=cythonize(Extension(
'pse',
sources=['pse.pyx'],
language='c++',
include_dirs=[numpy.get_include()],
library_dirs=[],
libraries=[],
extra_compile_args=['-O3'],
extra_link_args=[]
)))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import paddle
from paddle.nn import functional as F
from ppocr.postprocess.pse_postprocess.pse import pse
class PSEPostProcess(object):
"""
The post process for PSE.
"""
def __init__(self,
thresh=0.5,
box_thresh=0.85,
min_area=16,
box_type='box',
scale=4,
**kwargs):
assert box_type in ['box', 'poly'], 'Only box and poly is supported'
self.thresh = thresh
self.box_thresh = box_thresh
self.min_area = min_area
self.box_type = box_type
self.scale = scale
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if not isinstance(pred, paddle.Tensor):
pred = paddle.to_tensor(pred)
pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear')
score = F.sigmoid(pred[:, 0, :, :])
kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :]
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy()
kernels = kernels.numpy().astype(np.uint8)
boxes_batch = []
for batch_index in range(pred.shape[0]):
boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], shape_list[batch_index])
boxes_batch.append({'points': boxes, 'scores': scores})
return boxes_batch
def boxes_from_bitmap(self, score, kernels, shape):
label = pse(kernels, self.min_area)
return self.generate_box(score, label, shape)
def generate_box(self, score, label, shape):
src_h, src_w, ratio_h, ratio_w = shape
label_num = np.max(label) + 1
boxes = []
scores = []
for i in range(1, label_num):
ind = label == i
points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
if points.shape[0] < self.min_area:
label[ind] = 0
continue
score_i = np.mean(score[ind])
if score_i < self.box_thresh:
label[ind] = 0
continue
if self.box_type == 'box':
rect = cv2.minAreaRect(points)
bbox = cv2.boxPoints(rect)
elif self.box_type == 'poly':
box_height = np.max(points[:, 1]) + 10
box_width = np.max(points[:, 0]) + 10
mask = np.zeros((box_height, box_width), np.uint8)
mask[points[:, 1], points[:, 0]] = 255
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
bbox = np.squeeze(contours[0], 1)
else:
raise NotImplementedError
bbox[:, 0] = np.clip(
np.round(bbox[:, 0] / ratio_w), 0, src_w)
bbox[:, 1] = np.clip(
np.round(bbox[:, 1] / ratio_h), 0, src_h)
boxes.append(bbox)
scores.append(score_i)
return boxes, scores
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
EPS = 1e-6
def iou_single(a, b, mask, n_class):
valid = mask == 1
a = a.masked_select(valid)
b = b.masked_select(valid)
miou = []
for i in range(n_class):
if a.shape == [0] and a.shape==b.shape:
inter = paddle.to_tensor(0.0)
union = paddle.to_tensor(0.0)
else:
inter = ((a == i).logical_and(b == i)).astype('float32')
union = ((a == i).logical_or(b == i)).astype('float32')
miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
miou = sum(miou) / len(miou)
return miou
def iou(a, b, mask, n_class=2, reduce=True):
batch_size = a.shape[0]
a = a.reshape([batch_size, -1])
b = b.reshape([batch_size, -1])
mask = mask.reshape([batch_size, -1])
iou = paddle.zeros((batch_size,), dtype='float32')
for i in range(batch_size):
iou[i] = iou_single(a[i], b[i], mask[i], n_class)
if reduce:
iou = paddle.mean(iou)
return iou
\ No newline at end of file
...@@ -89,6 +89,14 @@ class TextDetector(object): ...@@ -89,6 +89,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2 postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0 postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3 postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -209,7 +217,7 @@ class TextDetector(object): ...@@ -209,7 +217,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1] preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2] preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm == 'DB': elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -217,7 +225,9 @@ class TextDetector(object): ...@@ -217,7 +225,9 @@ class TextDetector(object):
#self.predictor.try_shrink_memory() #self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if (self.det_algorithm == "SAST" and
self.det_sast_polygon) or (self.det_algorithm == "PSE" and
self.det_pse_box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else: else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -63,6 +63,13 @@ def init_args(): ...@@ -63,6 +63,13 @@ def init_args():
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False) parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
# PSE parmas
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='box')
parser.add_argument("--det_pse_scale", type=int, default=1)
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
......
...@@ -402,7 +402,7 @@ def preprocess(is_train=False): ...@@ -402,7 +402,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR' 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment