"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f5487628924ae591528ca472060bb5713d36bd32"
Unverified commit 85aeae71, authored by Double_V, committed by GitHub

Merge pull request #3002 from littletomatodonkey/dyg/add_distillation

add distillation
parents d93a445d 95d07675
Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  character_type: ch
  max_text_length: 25
  infer_mode: false
  use_space_char: false
  distributed: true
  save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.0005
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 1.0e-05

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001

Loss:
  name: CombinedLoss
  loss_config_list:
    - DistillationCTCLoss:
        weight: 1.0
        model_name_list: ["Student", "Teacher"]
        key: head_out
    - DistillationDMLLoss:
        weight: 1.0
        act: "softmax"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: head_out
    - DistillationDistanceLoss:
        weight: 1.0
        mode: "l2"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: backbone_out

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
      - ./train_data/train_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - RecAug:
      - CTCLabelEncode:
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label
            - length
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_sections: 1
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/val_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - CTCLabelEncode:
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label
            - length
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 8
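
This config wires the new pieces together: the Architecture builds a DistillationModel with Student and Teacher sub-models, the Loss combines CTC, DML and feature-distance terms, and the PostProcess/Metric select per-sub-model results. Training would go through the usual tools/train.py -c <config> entry point; as a hedged sketch (the local filename here is hypothetical), the structure can be inspected with:

import yaml

with open("rec_chinese_lite_train_distillation_v2.1.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Architecture"]["name"])          # DistillationModel
print(list(cfg["Architecture"]["Models"]))  # ['Student', 'Teacher']
print([list(d)[0] for d in cfg["Loss"]["loss_config_list"]])
# ['DistillationCTCLoss', 'DistillationDMLLoss', 'DistillationDistanceLoss']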
ppocr/losses/__init__.py
@@ -13,28 +13,37 @@
# limitations under the License.
import copy

import paddle
import paddle.nn as nn

# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss

# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss

# cls loss
from .cls_loss import ClsLoss

# e2e loss
from .e2e_pg_loss import PGLoss

# basic loss function
from .basic_loss import DistanceLoss

# combined loss function
from .combined_loss import CombinedLoss


def build_loss(config):
    # the imports above used to live inside this function; they are now
    # module-level, and 'CombinedLoss' joins the supported loss list
    support_dict = [
        'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
        'SRNLoss', 'PGLoss', 'CombinedLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    assert module_name in support_dict, Exception('loss only support {}'.format(
        support_dict))
    ...
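
For orientation, a minimal hedged sketch of how a config dict reaches the new CombinedLoss through build_loss (assuming the PaddleOCR repo root is on PYTHONPATH; the dict mirrors part of the Loss section above):

from ppocr.losses import build_loss

loss_cfg = {
    "name": "CombinedLoss",
    "loss_config_list": [
        {"DistillationCTCLoss": {
            "weight": 1.0,
            "model_name_list": ["Student", "Teacher"],
            "key": "head_out",
        }},
    ],
}
loss_fn = build_loss(loss_cfg)  # -> a CombinedLoss instance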
ppocr/losses/basic_loss.py (new file)
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss


class CELoss(nn.Layer):
    def __init__(self, epsilon=None):
        super().__init__()
        # label smoothing is only meaningful for epsilon in (0, 1)
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, x, label):
        if self.epsilon is not None:
            class_num = x.shape[-1]
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(x, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
        else:
            # a soft label has the same last dimension as the logits
            if label.shape[-1] == x.shape[-1]:
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
        return loss

class DMLLoss(nn.Layer):
    """
    DMLLoss (Deep Mutual Learning): the symmetric KL divergence between
    two model outputs, optionally passed through softmax or sigmoid first.
    """

    def __init__(self, act=None):
        super().__init__()
        if act is not None:
            assert act in ["softmax", "sigmoid"]
        if act == "softmax":
            self.act = nn.Softmax(axis=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            self.act = None

    def forward(self, out1, out2):
        if self.act is not None:
            out1 = self.act(out1)
            out2 = self.act(out2)
        log_out1 = paddle.log(out1)
        log_out2 = paddle.log(out2)
        loss = (F.kl_div(
            log_out1, out2, reduction='batchmean') + F.kl_div(
                log_out2, out1, reduction='batchmean')) / 2.0
        return loss
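
As a quick sanity check, a hedged sketch of DMLLoss on toy logits (shapes and values are illustrative only): identical inputs give a loss of zero, and the two KL terms make it symmetric in its arguments.

import paddle
from ppocr.losses.basic_loss import DMLLoss

dml = DMLLoss(act="softmax")
a = paddle.rand([4, 10])  # e.g. batch of 4, 10 classes
b = paddle.rand([4, 10])
print(float(dml(a, a)))                      # ~0: identical distributions
print(float(dml(a, b)), float(dml(b, a)))    # equal up to rounding: symmetric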

class DistanceLoss(nn.Layer):
    """
    DistanceLoss: wraps L1, L2 (MSE) or SmoothL1 loss, selected by `mode`.
    """

    def __init__(self, mode="l2", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        if mode == "l1":
            self.loss_func = nn.L1Loss(**kargs)
        elif mode == "l2":
            self.loss_func = nn.MSELoss(**kargs)
        elif mode == "smooth_l1":
            self.loss_func = nn.SmoothL1Loss(**kargs)

    def forward(self, x, y):
        return self.loss_func(x, y)
ppocr/losses/cls_loss.py
@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
        super(ClsLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, predicts, batch):  # renamed from __call__
        label = batch[1]
        loss = self.loss_func(input=predicts, label=label)
        return {'loss': loss}
ppocr/losses/combined_loss.py (new file)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss


class CombinedLoss(nn.Layer):
    """
    CombinedLoss: a weighted combination of loss functions, built from
    the `loss_config_list` entries of the Loss config.
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list,
                          list), 'loss_config_list should be a list'
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))

    def forward(self, input, batch, **kargs):
        loss_dict = {}
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                # wrap a bare tensor in a dict keyed by the loss index
                # (was keyed by str(loss), i.e. the tensor repr)
                loss = {"loss_{}".format(idx): loss}
            weight = self.loss_weight[idx]
            loss = {
                "{}_{}".format(key, idx): loss[key] * weight
                for key in loss
            }
            loss_dict.update(loss)
        # the total loss is the sum of all weighted sub-losses
        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
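
A hedged end-to-end sketch under assumed shapes: the nested predicts dict below imitates what DistillationModel returns under return_all_feats: true, and only the DML and distance terms are exercised (the CTC term would also need label tensors in batch). All names and shapes here are illustrative.

import paddle
from ppocr.losses.combined_loss import CombinedLoss

B, T, C = 4, 25, 100  # assumed batch size, time steps, classes

def fake_feats():
    # imitates one sub-model's output dict under return_all_feats: true
    return {"head_out": paddle.rand([B, T, C]),
            "backbone_out": paddle.rand([B, 288, 1, 80])}

predicts = {"Student": fake_feats(), "Teacher": fake_feats()}
loss_fn = CombinedLoss(loss_config_list=[
    {"DistillationDMLLoss": {"weight": 1.0, "act": "softmax",
                             "model_name_pairs": [["Student", "Teacher"]],
                             "key": "head_out"}},
    {"DistillationDistanceLoss": {"weight": 1.0, "mode": "l2",
                                  "model_name_pairs": [["Student", "Teacher"]],
                                  "key": "backbone_out"}},
])
out = loss_fn(predicts, batch=None)  # batch is unused by these two terms
print(sorted(out))  # per-term keys plus the aggregated 'loss'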
ppocr/losses/distillation_loss.py (new file)
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss


class DistillationDMLLoss(DMLLoss):
    """
    DMLLoss applied to pairs of sub-model outputs taken from a
    DistillationModel predicts dict.
    """

    def __init__(self, model_name_pairs=[], act=None, key=None,
                 name="loss_dml"):
        super().__init__(act=act)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss
        return loss_dict


class DistillationCTCLoss(CTCLoss):
    """
    CTCLoss applied to each named sub-model output in the predicts dict.
    """

    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
        super().__init__()
        self.model_name_list = model_name_list
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, model_name,
                                                idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, model_name)] = loss
        return loss_dict


class DistillationDistanceLoss(DistanceLoss):
    """
    DistanceLoss applied to pairs of sub-model outputs in the predicts dict.
    """

    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_distance",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_" + mode  # was hard-coded to "_l2" for all modes

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
                        key]
            else:
                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                               idx)] = loss
        return loss_dict
ppocr/losses/rec_ctc_loss.py
@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
        super(CTCLoss, self).__init__()
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')

    def forward(self, predicts, batch):  # renamed from __call__
        predicts = predicts.transpose((1, 0, 2))
        N, B, _ = predicts.shape
        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
        ...
ppocr/metrics/__init__.py
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import copy

__all__ = ["build_metric"]

from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric


def build_metric(config):
    # imports moved to module level; DistillationMetric added to the list
    support_dict = [
        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
    ]

    config = copy.deepcopy(config)
    module_name = config.pop("name")
    assert module_name in support_dict, Exception(
        "metric only support {}".format(support_dict))
    module_class = eval(module_name)(**config)
    return module_class
ppocr/metrics/distillation_metric.py (new file)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import copy

from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric


class DistillationMetric(object):
    def __init__(self,
                 key=None,
                 base_metric_name="RecMetric",
                 main_indicator='acc',
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        self.metrics = None

    def _init_metrics(self, preds):
        # lazily build one base metric per sub-model name found in preds
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        for key in preds:
            self.metrics[key] = getattr(mod, self.base_metric_name)(
                main_indicator=self.main_indicator, **self.kwargs)
            self.metrics[key].reset()

    def __call__(self, preds, *args, **kwargs):
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrics(preds)
        output = dict()
        for key in preds:
            metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
            for sub_key in metric:
                output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
            'norm_edit_dis': 0,
        }
        """
        output = dict()
        for key in self.metrics:
            metric = self.metrics[key].get_metric()
            # the sub-model named by self.key provides the main indicators
            if key == self.key:
                output.update(metric)
            else:
                for sub_key in metric:
                    output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def reset(self):
        for key in self.metrics:
            self.metrics[key].reset()
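
A hedged usage sketch with illustrative values; it assumes RecMetric's input format of a (predictions, labels) pair, each a list of (text, confidence) tuples, plus the Levenshtein dependency PaddleOCR already uses:

from ppocr.metrics.distillation_metric import DistillationMetric

metric = DistillationMetric(key="Student", base_metric_name="RecMetric")
preds = {
    "Student": ([("hello", 0.9)], [("hello", 1.0)]),
    "Teacher": ([("hallo", 0.8)], [("hello", 1.0)]),
}
metric(preds)  # accumulates one batch per sub-model
print(metric.get_metric())
# main indicators come from 'Student'; the others are prefixed, e.g. 'Teacher_acc'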
ppocr/modeling/architectures/__init__.py
@@ -13,12 +13,20 @@
# limitations under the License.
import copy
import importlib

from .base_model import BaseModel
from .distillation_model import DistillationModel

__all__ = ['build_model']


def build_model(config):
    config = copy.deepcopy(config)
    if "name" not in config:
        arch = BaseModel(config)
    else:
        # e.g. name: DistillationModel dispatches to the matching class
        name = config.pop("name")
        mod = importlib.import_module(__name__)
        arch = getattr(mod, name)(config)
    return arch
ppocr/modeling/architectures/base_model.py
@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
            config (dict): the super parameters for module.
        """
        super(BaseModel, self).__init__()
        in_channels = config.get('in_channels', 3)
        model_type = config['model_type']
        # build transform
@@ -68,14 +67,23 @@ class BaseModel(nn.Layer):
            config["Head"]['in_channels'] = in_channels
        self.head = build_head(config["Head"])

        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):
        y = dict()
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
        y["backbone_out"] = x
        if self.use_neck:
            x = self.neck(x)
        y["neck_out"] = x
        if data is None:
            x = self.head(x)
        else:
            x = self.head(x, data)
        y["head_out"] = x
        if self.return_all_feats:
            return y
        else:
            return x
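
A hedged sketch of the new return_all_feats switch, using the Student sub-config from the YAML above; out_channels is normally filled in from the character dictionary by the training program, so the value here is an assumption:

import paddle
from ppocr.modeling.architectures.base_model import BaseModel

config = {
    "model_type": "rec",
    "algorithm": "CRNN",
    "return_all_feats": True,
    "Transform": None,
    "Backbone": {"name": "MobileNetV3", "scale": 0.5,
                 "model_name": "small", "small_stride": [1, 2, 2, 2]},
    "Neck": {"name": "SequenceEncoder", "encoder_type": "rnn",
             "hidden_size": 48},
    "Head": {"name": "CTCHead", "fc_decay": 0.00001,
             "out_channels": 6625},  # assumed dictionary size + CTC blank
}
model = BaseModel(config)
y = model(paddle.rand([1, 3, 32, 320]))
print(y.keys())  # dict_keys(['backbone_out', 'neck_out', 'head_out'])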
ppocr/modeling/architectures/distillation_model.py (new file)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model

__all__ = ['DistillationModel']


class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        The module for OCR distillation.
        Args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                # init_model's new signature no longer takes a path argument
                # (it was called here as init_model(model, path=pretrained)),
                # so load the sub-model weights directly
                model.set_state_dict(paddle.load(pretrained + '.pdparams'))
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x):
        # run every sub-model on the same input and collect outputs by name
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x)
        return result_dict
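
Continuing the hedged BaseModel example above, wrapping two copies of the same hypothetical sub-config shows the dict-of-outputs contract that the distillation losses and post-processes rely on:

import copy
import paddle
from ppocr.modeling.architectures.distillation_model import DistillationModel

sub = {  # same assumed sub-config as in the BaseModel sketch above
    "pretrained": None, "freeze_params": False, "return_all_feats": True,
    "model_type": "rec", "algorithm": "CRNN", "Transform": None,
    "Backbone": {"name": "MobileNetV3", "scale": 0.5,
                 "model_name": "small", "small_stride": [1, 2, 2, 2]},
    "Neck": {"name": "SequenceEncoder", "encoder_type": "rnn",
             "hidden_size": 48},
    "Head": {"name": "CTCHead", "fc_decay": 0.00001, "out_channels": 6625},
}
config = {"Models": {"Student": copy.deepcopy(sub),
                     "Teacher": copy.deepcopy(sub)}}
model = DistillationModel(config)
out = model(paddle.rand([1, 3, 32, 320]))
print(list(out))             # ['Student', 'Teacher']
print(out["Student"].keys()) # backbone_out / neck_out / head_out per sub-model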
ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')  # the name='conv1' argument was removed

        self.stages = []
        self.out_channels = []
@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
                        kernel_size=k,
                        stride=s,
                        use_se=se,
                        act=nl))  # name= argument removed
                inplanes = make_divisible(scale * c)
                i += 1
            block_list.append(
@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
                    padding=0,
                    groups=1,
                    if_act=True,
                    act='hardswish'))  # name='conv_last' removed
            self.stages.append(nn.Sequential(*block_list))
            self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
        for i, stage in enumerate(self.stages):
@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
                 padding,
                 groups=1,
                 if_act=True,
                 act=None):  # the name parameter was removed
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
@@ -175,16 +171,9 @@
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)  # the explicit weight_attr name was dropped

        # BatchNorm no longer pins scale/offset/mean/variance parameter names
        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

    def forward(self, x):
        x = self.conv(x)
@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
                 kernel_size,
                 stride,
                 use_se,
                 act=None):  # the name parameter was removed
        super(ResidualUnit, self).__init__()
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se
@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act)
        if self.if_se:
            self.mid_se = SEModule(mid_channels)
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, inputs):
        x = self.expand_conv(inputs)
@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):

class SEModule(nn.Layer):
    def __init__(self, in_channels, reduction=4):  # name parameter removed
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.conv1 = nn.Conv2D(
@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
            out_channels=in_channels // reduction,
            kernel_size=1,
            stride=1,
            padding=0)  # explicit weight/bias parameter names removed
        self.conv2 = nn.Conv2D(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        ...
ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')  # name='conv1' removed
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl))  # name= argument removed
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)
@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
            padding=0,
            groups=1,
            if_act=True,
            act='hardswish')  # name='conv_last' removed
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = make_divisible(scale * cls_ch_squeeze)
        ...
ppocr/modeling/heads/det_db_head.py
@@ -23,10 +23,10 @@ import paddle.nn.functional as F
from paddle import ParamAttr


def get_bias_attr(k):  # the name parameter was removed
    stdv = 1.0 / math.sqrt(k * 1.0)
    initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
    bias_attr = ParamAttr(initializer=initializer)
    return bias_attr
@@ -38,18 +38,14 @@ class Head(nn.Layer):
            out_channels=in_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(),  # hard-coded name_list names removed
            bias_attr=False)
        self.conv_bn1 = nn.BatchNorm(
            num_channels=in_channels // 4,
            param_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)),
            bias_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1e-4)),
            act='relu')  # moving mean/variance names no longer pinned
        self.conv2 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
@@ -57,19 +53,14 @@
            kernel_size=2,
            stride=2,
            weight_attr=ParamAttr(
                initializer=paddle.nn.initializer.KaimingUniform()),
            bias_attr=get_bias_attr(in_channels // 4))
        self.conv_bn2 = nn.BatchNorm(
            num_channels=in_channels // 4,
            param_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)),
            bias_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1e-4)),
            act="relu")
        self.conv3 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
@@ -77,10 +68,8 @@ class Head(nn.Layer):
            kernel_size=2,
            stride=2,
            weight_attr=ParamAttr(
                initializer=paddle.nn.initializer.KaimingUniform()),
            bias_attr=get_bias_attr(in_channels // 4))

    def forward(self, x):
        x = self.conv1(x)
        ...
ppocr/modeling/heads/rec_ctc_head.py
@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F


def get_para_bias_attr(l2_decay, k):  # the name parameter was removed
    regularizer = paddle.regularizer.L2Decay(l2_decay)
    stdv = 1.0 / math.sqrt(k * 1.0)
    initializer = nn.initializer.Uniform(-stdv, stdv)
    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    return [weight_attr, bias_attr]
@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
    def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
        super(CTCHead, self).__init__()
        weight_attr, bias_attr = get_para_bias_attr(
            l2_decay=fc_decay, k=in_channels)
        self.fc = nn.Linear(
            in_channels,
            out_channels,
            weight_attr=weight_attr,
            bias_attr=bias_attr)  # the name='ctc_fc' argument was removed
        self.out_channels = out_channels

    def forward(self, x, labels=None):
        ...
ppocr/modeling/necks/db_fpn.py
@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
            in_channels=in_channels[0],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in3_conv = nn.Conv2D(
            in_channels=in_channels[1],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in4_conv = nn.Conv2D(
            in_channels=in_channels[2],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in5_conv = nn.Conv2D(
            in_channels=in_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p5_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p4_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p3_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p2_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        # the hard-coded layer names (conv2d_48 ... conv2d_55) were removed

    def forward(self, x):
        ...
ppocr/postprocess/__init__.py
@@ -21,18 +21,19 @@ import copy
__all__ = ['build_post_process']

from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess


def build_post_process(config, global_config=None):
    # imports moved to module level; DistillationCTCLabelDecode added
    support_dict = [
        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
        'DistillationCTCLabelDecode'
    ]

    config = copy.deepcopy(config)
    ...
ppocr/postprocess/rec_postprocess.py
@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
        return dict_character


class DistillationCTCLabelDecode(CTCLabelDecode):
    """
    Convert between text-label and text-index for the named sub-models
    of a distillation model.
    """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 model_name=["student"],
                 key=None,
                 **kwargs):
        super(DistillationCTCLabelDecode, self).__init__(
            character_dict_path, character_type, use_space_char)
        if not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output


class AttnLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """
    ...
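
A hedged decoding sketch: the dictionary path matches the config above, the channel count is derived from the dictionary (plus one for the CTC blank) so the shapes line up, and the random logits are illustrative only; preds are passed as numpy arrays in case tensor inputs are not handled:

import paddle
from ppocr.postprocess.rec_postprocess import DistillationCTCLabelDecode

dict_path = "ppocr/utils/ppocr_keys_v1.txt"
num_classes = sum(1 for _ in open(dict_path, "rb")) + 1  # + CTC blank

decode = DistillationCTCLabelDecode(
    character_dict_path=dict_path, character_type="ch",
    model_name=["Student", "Teacher"], key="head_out")
preds = {
    "Student": {"head_out": paddle.rand([1, 25, num_classes]).numpy()},
    "Teacher": {"head_out": paddle.rand([1, 25, num_classes]).numpy()},
}
print(decode(preds))  # {'Student': [(text, conf)], 'Teacher': [(text, conf)]}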
ppocr/utils/save_load.py
@@ -23,6 +23,8 @@ import six
import paddle

from ppocr.utils.logging import get_logger

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
        raise OSError('Failed to mkdir {}'.format(path))


# The old load_dygraph_pretrain helper, including its load_static_weights
# branch that converted static-graph checkpoints (squeezing and transposing
# encoder_rnn weights), was removed; pretrained weights are now loaded
# directly with paddle.load inside init_model.
def init_model(config, model, optimizer=None, lr_scheduler=None):
    """
    load model from checkpoint or pretrained_model
    """
    logger = get_logger()  # the logger argument was dropped from the signature
    global_config = config['Global']
    checkpoints = global_config.get('checkpoints')
    pretrained_model = global_config.get('pretrained_model')
@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
        best_model_dict = states_dict.get('best_model_dict', {})
        if 'epoch' in states_dict:
            best_model_dict['start_epoch'] = states_dict['epoch'] + 1
        logger.info("resume from {}".format(checkpoints))
    elif pretrained_model:
        if not isinstance(pretrained_model, list):
            pretrained_model = [pretrained_model]
        for pretrained in pretrained_model:
            if not (os.path.isdir(pretrained) or
                    os.path.exists(pretrained + '.pdparams')):
                raise ValueError("Model pretrain path {} does not "
                                 "exist.".format(pretrained))
            param_state_dict = paddle.load(pretrained + '.pdparams')
            model.set_state_dict(param_state_dict)
        logger.info("load pretrained model from {}".format(
            pretrained_model))
    else:
        ...