Commit 2401626a authored by dyning's avatar dyning
Browse files

Merge branch 'develop' of https://github.com/MissPenguin/PaddleOCR into develop

parents 160bb06e b0171a71
......@@ -125,8 +125,8 @@ class DBProcessTest(object):
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'det_image_shape' in params:
self.image_shape = params['det_image_shape']
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
......
......@@ -455,17 +455,23 @@ class EASTProcessTrain(object):
class EASTProcessTest(object):
def __init__(self, params):
super(EASTProcessTest, self).__init__()
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400
def resize_image(self, im):
def resize_image_type0(self, im):
"""
resize image to a size multiple of 32 which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
max_side_len = self.max_side_len
h, w, _ = im.shape
......@@ -495,13 +501,30 @@ class EASTProcessTest(object):
resize_w = 32
else:
resize_w = (resize_w // 32 - 1) * 32
im = cv2.resize(im, (int(resize_w), int(resize_h)))
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
except:
print(im.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)
def __call__(self, im):
im, (ratio_h, ratio_w) = self.resize_image(im)
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im[:, :, ::-1].astype(np.float32)
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
import paddle.fluid as fluid
from ..common_functions import conv_bn_layer, deconv_bn_layer
from collections import OrderedDict
class EASTHead(object):
......@@ -110,7 +111,7 @@ class EASTHead(object):
def __call__(self, inputs):
f_common = self.unet_fusion(inputs)
f_score, f_geo = self.detector_header(f_common)
predicts = {}
predicts = OrderedDict()
predicts['f_score'] = f_score
predicts['f_geo'] = f_geo
return predicts
......@@ -20,6 +20,13 @@ import numpy as np
from .locality_aware_nms import nms_locality
import cv2
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import lanms
class EASTPostPocess(object):
"""
......@@ -66,7 +73,8 @@ class EASTPostPocess(object):
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
# boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
if boxes.shape[0] == 0:
return []
# Here we filter some low score boxes by the average score map,
......
#!/usr/bin/env python
#
# Copyright (C) 2014 Google Inc.
#
# This file is part of YouCompleteMe.
#
# YouCompleteMe is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# YouCompleteMe is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with YouCompleteMe. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import glob
import ycm_core
# These are the compilation flags that will be used in case there's no
# compilation database set (by default, one is not set).
# CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR.
sys.path.append(os.path.dirname(__file__))
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
from plumbum.cmd import python_config
flags = [
'-Wall',
'-Wextra',
'-Wnon-virtual-dtor',
'-Winvalid-pch',
'-Wno-unused-local-typedefs',
'-std=c++11',
'-x', 'c++',
'-Iinclude',
] + python_config('--cflags').split()
# Set this to the absolute path to the folder (NOT the file!) containing the
# compile_commands.json file to use that instead of 'flags'. See here for
# more details: http://clang.llvm.org/docs/JSONCompilationDatabase.html
#
# Most projects will NOT need to set this to anything; you can just change the
# 'flags' list of compilation flags.
compilation_database_folder = ''
if os.path.exists( compilation_database_folder ):
database = ycm_core.CompilationDatabase( compilation_database_folder )
else:
database = None
SOURCE_EXTENSIONS = [ '.cpp', '.cxx', '.cc', '.c', '.m', '.mm' ]
def DirectoryOfThisScript():
return os.path.dirname( os.path.abspath( __file__ ) )
def MakeRelativePathsInFlagsAbsolute( flags, working_directory ):
if not working_directory:
return list( flags )
new_flags = []
make_next_absolute = False
path_flags = [ '-isystem', '-I', '-iquote', '--sysroot=' ]
for flag in flags:
new_flag = flag
if make_next_absolute:
make_next_absolute = False
if not flag.startswith( '/' ):
new_flag = os.path.join( working_directory, flag )
for path_flag in path_flags:
if flag == path_flag:
make_next_absolute = True
break
if flag.startswith( path_flag ):
path = flag[ len( path_flag ): ]
new_flag = path_flag + os.path.join( working_directory, path )
break
if new_flag:
new_flags.append( new_flag )
return new_flags
def IsHeaderFile( filename ):
extension = os.path.splitext( filename )[ 1 ]
return extension in [ '.h', '.hxx', '.hpp', '.hh' ]
def GetCompilationInfoForFile( filename ):
# The compilation_commands.json file generated by CMake does not have entries
# for header files. So we do our best by asking the db for flags for a
# corresponding source file, if any. If one exists, the flags for that file
# should be good enough.
if IsHeaderFile( filename ):
basename = os.path.splitext( filename )[ 0 ]
for extension in SOURCE_EXTENSIONS:
replacement_file = basename + extension
if os.path.exists( replacement_file ):
compilation_info = database.GetCompilationInfoForFile(
replacement_file )
if compilation_info.compiler_flags_:
return compilation_info
return None
return database.GetCompilationInfoForFile( filename )
# This is the entry point; this function is called by ycmd to produce flags for
# a file.
def FlagsForFile( filename, **kwargs ):
if database:
# Bear in mind that compilation_info.compiler_flags_ does NOT return a
# python list, but a "list-like" StringVec object
compilation_info = GetCompilationInfoForFile( filename )
if not compilation_info:
return None
final_flags = MakeRelativePathsInFlagsAbsolute(
compilation_info.compiler_flags_,
compilation_info.compiler_working_dir_ )
else:
relative_to = DirectoryOfThisScript()
final_flags = MakeRelativePathsInFlagsAbsolute( flags, relative_to )
return {
'flags': final_flags,
'do_cache': True
}
CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
LDFLAGS = $(shell python3-config --ldflags)
DEPS = lanms.h $(shell find include -xtype f)
CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp
LIB_SO = adaptor.so
$(LIB_SO): $(CXX_SOURCES) $(DEPS)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC
clean:
rm -rf $(LIB_SO)
import subprocess
import os
import numpy as np
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value
raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR))
def merge_quadrangle_n9(polys, thres=0.3, precision=10000):
from .adaptor import merge_quadrangle_n9 as nms_impl
if len(polys) == 0:
return np.array([], dtype='float32')
p = polys.copy()
p[:,:8] *= precision
ret = np.array(nms_impl(p, thres), dtype='float32')
ret[:,:8] /= precision
return ret
import numpy as np
from . import merge_quadrangle_n9
if __name__ == '__main__':
# unit square with confidence 1
q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32')
print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2])))
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "lanms.h"
namespace py = pybind11;
namespace lanms_adaptor {
std::vector<std::vector<float>> polys2floats(const std::vector<lanms::Polygon> &polys) {
std::vector<std::vector<float>> ret;
for (size_t i = 0; i < polys.size(); i ++) {
auto &p = polys[i];
auto &poly = p.poly;
ret.emplace_back(std::vector<float>{
float(poly[0].X), float(poly[0].Y),
float(poly[1].X), float(poly[1].Y),
float(poly[2].X), float(poly[2].Y),
float(poly[3].X), float(poly[3].Y),
float(p.score),
});
}
return ret;
}
/**
*
* \param quad_n9 an n-by-9 numpy array, where first 8 numbers denote the
* quadrangle, and the last one is the score
* \param iou_threshold two quadrangles with iou score above this threshold
* will be merged
*
* \return an n-by-9 numpy array, the merged quadrangles
*/
std::vector<std::vector<float>> merge_quadrangle_n9(
py::array_t<float, py::array::c_style | py::array::forcecast> quad_n9,
float iou_threshold) {
auto pbuf = quad_n9.request();
if (pbuf.ndim != 2 || pbuf.shape[1] != 9)
throw std::runtime_error("quadrangles must have a shape of (n, 9)");
auto n = pbuf.shape[0];
auto ptr = static_cast<float *>(pbuf.ptr);
return polys2floats(lanms::merge_quadrangle_n9(ptr, n, iou_threshold));
}
}
PYBIND11_PLUGIN(adaptor) {
py::module m("adaptor", "NMS");
m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9,
"merge quadrangels");
return m.ptr();
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment