Unverified Commit 0458f0cc authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #13 from PaddlePaddle/dygraph

Dygraph
parents 04b0318b 836839bb
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
class EASTLoss(nn.Layer):
"""
"""
def __init__(self,
eps=1e-6,
**kwargs):
super(EASTLoss, self).__init__()
self.dice_loss = DiceLoss(eps=eps)
def forward(self, predicts, labels):
l_score, l_geo, l_mask = labels[1:]
f_score = predicts['f_score']
f_geo = predicts['f_geo']
dice_loss = self.dice_loss(f_score, l_score, l_mask)
#smoooth_l1_loss
channels = 8
l_geo_split = paddle.split(
l_geo, num_or_sections=channels + 1, axis=1)
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
smooth_l1 = 0
for i in range(0, channels):
geo_diff = l_geo_split[i] - f_geo_split[i]
abs_geo_diff = paddle.abs(geo_diff)
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
out_loss = l_geo_split[-1] / channels * in_loss * l_score
smooth_l1 += out_loss
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
dice_loss = dice_loss * 0.01
total_loss = dice_loss + smooth_l1_loss
losses = {"loss":total_loss, \
"dice_loss":dice_loss,\
"smooth_l1_loss":smooth_l1_loss}
return losses
This diff is collapsed.
......@@ -16,7 +16,7 @@ from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transform import build_transform
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
......
......@@ -19,6 +19,7 @@ def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
......
......@@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None):
class MobileNetV3(nn.Layer):
def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs):
def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
......@@ -103,6 +111,7 @@ class MobileNetV3(nn.Layer):
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
if s == 2 and i > 2:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
......
This diff is collapsed.
......@@ -18,13 +18,15 @@ __all__ = ['build_head']
def build_head(config):
# det head
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
# rec head
from .rec_ctc_head import CTCHead
# cls head
from .cls_head import ClsHead
support_dict = ['DBHead', 'CTCHead', 'ClsHead']
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment