Commit e7ad27c3 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix conflicts

parents c0b4cefd 91f5ab5c
This diff is collapsed.
......@@ -18,13 +18,15 @@ __all__ = ['build_head']
def build_head(config):
# det head
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
# rec head
from .rec_ctc_head import CTCHead
# cls head
from .cls_head import ClsHead
support_dict = ['DBHead', 'CTCHead', 'ClsHead']
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class EASTHead(nn.Layer):
"""
"""
def __init__(self, in_channels, model_name, **kwargs):
super(EASTHead, self).__init__()
self.model_name = model_name
if self.model_name == "large":
num_outputs = [128, 64, 1, 8]
else:
num_outputs = [64, 32, 1, 8]
self.det_conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=num_outputs[0],
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="det_head1")
self.det_conv2 = ConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="det_head2")
self.score_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name="f_score")
self.geo_conv = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name="f_geo")
def forward(self, x):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
f_score = F.sigmoid(f_score)
f_geo = self.geo_conv(f_det)
f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800
pred = {'f_score': f_score, 'f_geo': f_geo}
return pred
This diff is collapsed.
......@@ -16,8 +16,10 @@ __all__ = ['build_neck']
def build_neck(config):
from .db_fpn import DBFPN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
support_dict = ['DBFPN', 'SequenceEncoder']
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
This diff is collapsed.
This diff is collapsed.
......@@ -24,11 +24,13 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
from .cls_postprocess import ClsPostProcess
support_dict = [
'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
]
config = copy.deepcopy(config)
......
This diff is collapsed.
This diff is collapsed.
......@@ -27,7 +27,7 @@ class BaseRecLabelDecode(object):
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str)
support_character_type, character_type)
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
......
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import logging
logger = logging.getLogger(__name__)
def check_config_params(config, config_name, params):
for param in params:
if param not in config:
err = "param %s didn't find in %s!" % (param, config_name)
assert False, err
return
......@@ -230,10 +230,10 @@ def draw_ocr_box_txt(image,
box[2][1], box[3][0], box[3][1]
],
outline=color)
box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
1]) ** 2)
box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
1]) ** 2)
box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
1])**2)
box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
1])**2)
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
......@@ -260,7 +260,6 @@ def str_count(s):
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
......@@ -295,7 +294,6 @@ def text_visual(texts,
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
"""
if scores is not None:
assert len(texts) == len(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment