Commit f20f6d2d authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'upstream/dygraph' into dy1

parents 647db30f acd479ea
...@@ -33,7 +33,7 @@ import paddle.distributed as dist ...@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp) ...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp) signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test' assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test." ], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger) dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card'] batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last'] drop_last = loader_config['drop_last']
shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers'] num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys(): if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory'] use_shared_memory = loader_config['use_shared_memory']
...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger): ...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger):
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
else: else:
#Distribute data to single card #Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
data_loader = DataLoader( data_loader = DataLoader(
......
...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap ...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np import numpy as np
import string
class ClsLabelEncode(object): class ClsLabelEncode(object):
...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object): ...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object):
character_type='ch', character_type='ch',
use_space_char=False): use_space_char=False):
support_character_type = [ support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type in ["ch", "french", "german", "japan", "korean"]: elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = "" self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object): ...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object):
if use_space_char: if use_space_char:
self.character_str += " " self.character_str += " "
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
self.character_type = character_type self.character_type = character_type
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
...@@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
char_num = len(self.character_str)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import cv2 import cv2
import numpy as np import numpy as np
...@@ -77,6 +63,26 @@ class RecResizeImg(object): ...@@ -77,6 +63,26 @@ class RecResizeImg(object):
return data return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
def resize_norm_img(img, image_shape): def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape): ...@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape): def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape # todo: change to 0 and modified image shape
max_wh_ratio = 0 max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1] h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio) max_wh_ratio = max(max_wh_ratio, ratio)
...@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape): ...@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im return padding_im
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag(): def flag():
""" """
flag flag
......
...@@ -21,7 +21,7 @@ from .imaug import transform, create_operators ...@@ -21,7 +21,7 @@ from .imaug import transform, create_operators
class LMDBDateSet(Dataset): class LMDBDateSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(LMDBDateSet, self).__init__() super(LMDBDateSet, self).__init__()
global_config = config['Global'] global_config = config['Global']
......
...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators ...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset): class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset): ...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset): ...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
data_lines.extend(lines) data_lines.extend(lines)
...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset): ...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
......
...@@ -23,11 +23,14 @@ def build_loss(config): ...@@ -23,11 +23,14 @@ def build_loss(config):
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_srn_loss import SRNLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SRNLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = batch[1]
casted_label = paddle.cast(x=label, dtype='int64')
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
cost_vsfd = self.loss_func(predict, label=casted_label)
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
...@@ -33,8 +33,6 @@ class RecMetric(object): ...@@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target: if pred == target:
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
...@@ -50,7 +48,7 @@ class RecMetric(object): ...@@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = self.correct_num / self.all_num acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -68,11 +68,14 @@ class BaseModel(nn.Layer): ...@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
def forward(self, x): def forward(self, x, data=None):
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
x = self.backbone(x) x = self.backbone(x)
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
x = self.head(x) if data is None:
x = self.head(x)
else:
x = self.head(x, data)
return x return x
...@@ -24,7 +24,8 @@ def build_backbone(config, model_type): ...@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls': elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer): ...@@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', 2], [5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2], [3, 240, 80, False, 'hardswish', 2],
[3, 200, 80, False, 'hard_swish', 1], [3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hard_swish', 1], [3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hard_swish', 1], [3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hard_swish', 1], [3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hard_swish', 1], [3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hard_swish', 2], [5, 672, 160, True, 'hardswish', 2],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hardswish', 1],
] ]
cls_ch_squeeze = 960 cls_ch_squeeze = 960
elif model_name == "small": elif model_name == "small":
...@@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer): ...@@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', 2], [3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2], [3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1], [3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2], [5, 96, 40, True, 'hardswish', 2],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hard_swish', 1], [5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hard_swish', 1], [5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hard_swish', 2], [5, 288, 96, True, 'hardswish', 2],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hardswish', 1],
] ]
cls_ch_squeeze = 576 cls_ch_squeeze = 576
else: else:
...@@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer): ...@@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer):
padding=1, padding=1,
groups=1, groups=1,
if_act=True, if_act=True,
act='hard_swish', act='hardswish',
name='conv1') name='conv1')
self.stages = [] self.stages = []
...@@ -112,7 +112,8 @@ class MobileNetV3(nn.Layer): ...@@ -112,7 +112,8 @@ class MobileNetV3(nn.Layer):
inplanes = make_divisible(inplanes * scale) inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg: for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se se = se and not self.disable_se
if s == 2 and i > 2: start_idx = 2 if model_name == 'large' else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes) self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list)) self.stages.append(nn.Sequential(*block_list))
block_list = [] block_list = []
...@@ -137,7 +138,7 @@ class MobileNetV3(nn.Layer): ...@@ -137,7 +138,7 @@ class MobileNetV3(nn.Layer):
padding=0, padding=0,
groups=1, groups=1,
if_act=True, if_act=True,
act='hard_swish', act='hardswish',
name='conv_last')) name='conv_last'))
self.stages.append(nn.Sequential(*block_list)) self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
...@@ -191,10 +192,11 @@ class ConvBNLayer(nn.Layer): ...@@ -191,10 +192,11 @@ class ConvBNLayer(nn.Layer):
if self.if_act: if self.if_act:
if self.act == "relu": if self.act == "relu":
x = F.relu(x) x = F.relu(x)
elif self.act == "hard_swish": elif self.act == "hardswish":
x = F.activation.hard_swish(x) x = F.hardswish(x)
else: else:
print("The activation function is selected incorrectly.") print("The activation function({}) is selected incorrectly.".
format(self.act))
exit() exit()
return x return x
...@@ -281,5 +283,5 @@ class SEModule(nn.Layer): ...@@ -281,5 +283,5 @@ class SEModule(nn.Layer):
outputs = self.conv1(outputs) outputs = self.conv1(outputs)
outputs = F.relu(outputs) outputs = F.relu(outputs)
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = F.activation.hard_sigmoid(outputs) outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs return inputs * outputs
...@@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer): ...@@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', (large_stride[2], 1)], [5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 1], [3, 240, 80, False, 'hardswish', 1],
[3, 200, 80, False, 'hard_swish', 1], [3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hard_swish', 1], [3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hard_swish', 1], [3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hard_swish', 1], [3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hard_swish', 1], [3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)], [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hardswish', 1],
] ]
cls_ch_squeeze = 960 cls_ch_squeeze = 960
elif model_name == "small": elif model_name == "small":
...@@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer): ...@@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', (small_stride[0], 1)], [3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)], [3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1], [3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)], [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hard_swish', 1], [5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hard_swish', 1], [5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)], [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hardswish', 1],
] ]
cls_ch_squeeze = 576 cls_ch_squeeze = 576
else: else:
...@@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer): ...@@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer):
padding=1, padding=1,
groups=1, groups=1,
if_act=True, if_act=True,
act='hard_swish', act='hardswish',
name='conv1') name='conv1')
i = 0 i = 0
block_list = [] block_list = []
...@@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer): ...@@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer):
padding=0, padding=0,
groups=1, groups=1,
if_act=True, if_act=True,
act='hard_swish', act='hardswish',
name='conv_last') name='conv_last')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import paddle.fluid as fluid
import paddle
import numpy as np
__all__ = ["ResNetFPN"]
class ResNetFPN(nn.Layer):
def __init__(self, in_channels=1, layers=50, **kwargs):
super(ResNetFPN, self).__init__()
supported_layers = {
18: {
'depth': [2, 2, 2, 2],
'block_class': BasicBlock
},
34: {
'depth': [3, 4, 6, 3],
'block_class': BasicBlock
},
50: {
'depth': [3, 4, 6, 3],
'block_class': BottleneckBlock
},
101: {
'depth': [3, 4, 23, 3],
'block_class': BottleneckBlock
},
152: {
'depth': [3, 8, 36, 3],
'block_class': BottleneckBlock
}
}
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
num_filters = [64, 128, 256, 512]
self.depth = supported_layers[layers]['depth']
self.F = []
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act="relu",
name="conv1")
self.block_list = []
in_ch = 64
if layers >= 50:
for block in range(len(self.depth)):
for i in range(self.depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
block_list = self.add_sublayer(
"bottleneckBlock_{}_{}".format(block, i),
BottleneckBlock(
in_channels=in_ch,
out_channels=num_filters[block],
stride=stride_list[block] if i == 0 else 1,
name=conv_name))
in_ch = num_filters[block] * 4
self.block_list.append(block_list)
self.F.append(block_list)
else:
for block in range(len(self.depth)):
for i in range(self.depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
if i == 0 and block != 0:
stride = (2, 1)
else:
stride = (1, 1)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
in_channels=in_ch,
out_channels=num_filters[block],
stride=stride_list[block] if i == 0 else 1,
is_first=block == i == 0,
name=conv_name))
in_ch = basic_block.out_channels
self.block_list.append(basic_block)
out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
self.base_block = []
self.conv_trans = []
self.bn_block = []
for i in [-2, -3]:
in_channels = out_ch_list[i + 1] + out_ch_list[i]
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_0".format(i),
nn.Conv2D(
in_channels=in_channels,
out_channels=out_ch_list[i],
kernel_size=1,
weight_attr=ParamAttr(trainable=True),
bias_attr=ParamAttr(trainable=True))))
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_1".format(i),
nn.Conv2D(
in_channels=out_ch_list[i],
out_channels=out_ch_list[i],
kernel_size=3,
padding=1,
weight_attr=ParamAttr(trainable=True),
bias_attr=ParamAttr(trainable=True))))
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_2".format(i),
nn.BatchNorm(
num_channels=out_ch_list[i],
act="relu",
param_attr=ParamAttr(trainable=True),
bias_attr=ParamAttr(trainable=True))))
self.base_block.append(
self.add_sublayer(
"F_{}_base_block_3".format(i),
nn.Conv2D(
in_channels=out_ch_list[i],
out_channels=512,
kernel_size=1,
bias_attr=ParamAttr(trainable=True),
weight_attr=ParamAttr(trainable=True))))
self.out_channels = 512
def __call__(self, x):
x = self.conv(x)
fpn_list = []
F = []
for i in range(len(self.depth)):
fpn_list.append(np.sum(self.depth[:i + 1]))
for i, block in enumerate(self.block_list):
x = block(x)
for number in fpn_list:
if i + 1 == number:
F.append(x)
base = F[-1]
j = 0
for i, block in enumerate(self.base_block):
if i % 3 == 0 and i < 6:
j = j + 1
b, c, w, h = F[-j - 1].shape
if [w, h] == list(base.shape[2:]):
base = base
else:
base = self.conv_trans[j - 1](base)
base = self.bn_block[j - 1](base)
base = paddle.concat([base, F[-j - 1]], axis=1)
base = block(base)
return base
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2 if stride == (1, 1) else kernel_size,
dilation=2 if stride == (1, 1) else 1,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
bias_attr=False, )
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name=name + '.output.1.w_0'),
bias_attr=ParamAttr(name=name + '.output.1.b_0'),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def __call__(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class ShortCut(nn.Layer):
def __init__(self, in_channels, out_channels, stride, name, is_first=False):
super(ShortCut, self).__init__()
self.use_conv = True
if in_channels != out_channels or stride != 1 or is_first == True:
if stride == (1, 1):
self.conv = ConvBNLayer(
in_channels, out_channels, 1, 1, name=name)
else: # stride==(2,2)
self.conv = ConvBNLayer(
in_channels, out_channels, 1, stride, name=name)
else:
self.use_conv = False
def forward(self, x):
if self.use_conv:
x = self.conv(x)
return x
class BottleneckBlock(nn.Layer):
def __init__(self, in_channels, out_channels, stride, name):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
self.short = ShortCut(
in_channels=in_channels,
out_channels=out_channels * 4,
stride=stride,
is_first=False,
name=name + "_branch1")
self.out_channels = out_channels * 4
def forward(self, x):
y = self.conv0(x)
y = self.conv1(y)
y = self.conv2(y)
y = y + self.short(x)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self, in_channels, out_channels, stride, name, is_first):
super(BasicBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
self.short = ShortCut(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
is_first=is_first,
name=name + "_branch1")
self.out_channels = out_channels
def forward(self, x):
y = self.conv0(x)
y = self.conv1(y)
y = y + self.short(x)
return F.relu(y)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment