Commit 2b3f89f0 authored by LDOUBLEV

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

parents 591b92e8 882e6e54
@@ -17,9 +17,14 @@ import string
 class RecMetric(object):
-    def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
+    def __init__(self,
+                 main_indicator='acc',
+                 is_filter=False,
+                 ignore_space=True,
+                 **kwargs):
         self.main_indicator = main_indicator
         self.is_filter = is_filter
+        self.ignore_space = ignore_space
         self.eps = 1e-5
         self.reset()
@@ -34,8 +39,9 @@ class RecMetric(object):
         all_num = 0
         norm_edit_dis = 0.0
         for (pred, pred_conf), (target, _) in zip(preds, labels):
-            pred = pred.replace(" ", "")
-            target = target.replace(" ", "")
+            if self.ignore_space:
+                pred = pred.replace(" ", "")
+                target = target.replace(" ", "")
             if self.is_filter:
                 pred = self._normalize_text(pred)
                 target = self._normalize_text(target)
...
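The new `ignore_space` flag simply makes the previously hard-coded space stripping optional. A minimal standalone sketch of the effect (it mirrors the loop above rather than invoking the class):

def exact_match(pred, target, ignore_space=True):
    # mirrors RecMetric: optionally drop spaces before comparing
    if ignore_space:
        pred = pred.replace(" ", "")
        target = target.replace(" ", "")
    return pred == target

print(exact_match("hello world", "helloworld"))                      # True
print(exact_match("hello world", "helloworld", ignore_space=False))  # False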
@@ -83,7 +83,11 @@ class BaseModel(nn.Layer):
             y["neck_out"] = x
         if self.use_head:
             x = self.head(x, targets=data)
-        if isinstance(x, dict):
+        # for multi head, save ctc neck out for udml
+        if isinstance(x, dict) and 'ctc_neck' in x.keys():
+            y["neck_out"] = x["ctc_neck"]
+            y["head_out"] = x
+        elif isinstance(x, dict):
             y.update(x)
         else:
             y["head_out"] = x
...
@@ -53,8 +53,8 @@ class DistillationModel(nn.Layer):
             self.model_list.append(self.add_sublayer(key, model))
             self.model_name_list.append(key)

-    def forward(self, x):
+    def forward(self, x, data=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
-            result_dict[model_name] = self.model_list[idx](x)
+            result_dict[model_name] = self.model_list[idx](x, data)
         return result_dict
@@ -31,9 +31,11 @@ def build_backbone(config, model_type):
         from .rec_resnet_aster import ResNet_ASTER
         from .rec_micronet import MicroNet
         from .rec_efficientb3_pren import EfficientNetb3_PREN
+        from .rec_svtrnet import SVTRNet
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
-            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN'
+            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
+            'SVTRNet'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
...
@@ -103,7 +103,12 @@ class DepthwiseSeparable(nn.Layer):
 class MobileNetV1Enhance(nn.Layer):
-    def __init__(self, in_channels=3, scale=0.5, **kwargs):
+    def __init__(self,
+                 in_channels=3,
+                 scale=0.5,
+                 last_conv_stride=1,
+                 last_pool_type='max',
+                 **kwargs):
         super().__init__()
         self.scale = scale
         self.block_list = []
@@ -200,7 +205,7 @@ class MobileNetV1Enhance(nn.Layer):
             num_filters1=1024,
             num_filters2=1024,
             num_groups=1024,
-            stride=1,
+            stride=last_conv_stride,
             dw_size=5,
             padding=2,
             use_se=True,
@@ -208,8 +213,10 @@ class MobileNetV1Enhance(nn.Layer):
         self.block_list.append(conv6)
         self.block_list = nn.Sequential(*self.block_list)
-        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        if last_pool_type == 'avg':
+            self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
+        else:
+            self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
         self.out_channels = int(1024 * scale)

     def forward(self, inputs):
...
This diff is collapsed.
@@ -32,6 +32,7 @@ def build_head(config):
     from .rec_sar_head import SARHead
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
+    from .rec_multi_head import MultiHead

     # cls head
     from .cls_head import ClsHead
@@ -44,7 +45,8 @@ def build_head(config):
     support_dict = [
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
-        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead'
+        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
+        'MultiHead'
     ]

     #table head
...
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead


class MultiHead(nn.Layer):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop('head_list')
self.gtc_head = 'sar'
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
if name == 'SARHead':
# sar head
sar_args = self.head_list[idx][name]
self.sar_head = eval(name)(in_channels=in_channels, \
out_channels=out_channels_list['SARLabelDecode'], **sar_args)
elif name == 'CTCHead':
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]['Neck']
encoder_type = neck_args.pop('name')
self.encoder = encoder_type
self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
encoder_type=encoder_type, **neck_args)
# ctc head
head_args = self.head_list[idx][name]['Head']
self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
out_channels=out_channels_list['CTCLabelDecode'], **head_args)
else:
raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
ctc_encoder = self.ctc_encoder(x)
ctc_out = self.ctc_head(ctc_encoder, targets)
head_out = dict()
head_out['ctc'] = ctc_out
head_out['ctc_neck'] = ctc_encoder
# eval mode
if not self.training:
return ctc_out
if self.gtc_head == 'sar':
sar_out = self.sar_head(x, targets[1:])
head_out['sar'] = sar_out
return head_out
else:
return head_out
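MultiHead builds a CTC branch (Im2Seq + SequenceEncoder + CTCHead) and a SAR branch from a config-driven `head_list`; in training mode its output dict carries 'ctc', 'ctc_neck', and 'sar', while in eval mode only the CTC result is returned. A minimal construction sketch, assuming the PaddleOCR package is importable; the channel and alphabet sizes below are illustrative, not mandated by the code:

from ppocr.modeling.heads.rec_multi_head import MultiHead

head_list = [
    {'CTCHead': {'Neck': {'name': 'svtr', 'dims': 64, 'depth': 2,
                          'hidden_dims': 120, 'use_guide': True},
                 'Head': {'fc_decay': 1e-05}}},
    {'SARHead': {'enc_dim': 512, 'max_text_length': 25}},
]
# out_channels_list maps each decoder to its alphabet size (illustrative values)
head = MultiHead(in_channels=512,
                 out_channels_list={'CTCLabelDecode': 6625,
                                    'SARLabelDecode': 6627},
                 head_list=head_list)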
@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder):
 class SARHead(nn.Layer):
     def __init__(self,
+                 in_channels,
                  out_channels,
+                 enc_dim=512,
+                 max_text_length=30,
                  enc_bi_rnn=False,
                  enc_drop_rnn=0.1,
                  enc_gru=False,
@@ -358,14 +361,17 @@ class SARHead(nn.Layer):
                  dec_gru=False,
                  d_k=512,
                  pred_dropout=0.1,
-                 max_text_length=30,
                  pred_concat=True,
                  **kwargs):
         super(SARHead, self).__init__()

         # encoder module
         self.encoder = SAREncoder(
-            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
+            enc_bi_rnn=enc_bi_rnn,
+            enc_drop_rnn=enc_drop_rnn,
+            enc_gru=enc_gru,
+            d_model=in_channels,
+            d_enc=enc_dim)

         # decoder module
         self.decoder = ParallelSARDecoder(
@@ -374,6 +380,8 @@ class SARHead(nn.Layer):
             dec_bi_rnn=dec_bi_rnn,
             dec_drop_rnn=dec_drop_rnn,
             dec_gru=dec_gru,
+            d_model=in_channels,
+            d_enc=enc_dim,
             d_k=d_k,
             pred_dropout=pred_dropout,
             max_text_length=max_text_length,
@@ -390,7 +398,7 @@ class SARHead(nn.Layer):
             label = paddle.to_tensor(label, dtype='int64')
             final_out = self.decoder(
                 feat, holistic_feat, label, img_metas=targets)
-        if not self.training:
+        else:
             final_out = self.decoder(
                 feat,
                 holistic_feat,
...
@@ -16,9 +16,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import paddle
 from paddle import nn

 from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
+from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_


 class Im2Seq(nn.Layer):
@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer):
         return x


+class EncoderWithSVTR(nn.Layer):
+    def __init__(
+            self,
+            in_channels,
+            dims=64,  # XS
+            depth=2,
+            hidden_dims=120,
+            use_guide=False,
+            num_heads=8,
+            qkv_bias=True,
+            mlp_ratio=2.0,
+            drop_rate=0.1,
+            attn_drop_rate=0.1,
+            drop_path=0.,
+            qk_scale=None):
+        super(EncoderWithSVTR, self).__init__()
+        self.depth = depth
+        self.use_guide = use_guide
+        self.conv1 = ConvBNLayer(
+            in_channels, in_channels // 8, padding=1, act=nn.Swish)
+        self.conv2 = ConvBNLayer(
+            in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
+        self.svtr_block = nn.LayerList([
+            Block(
+                dim=hidden_dims,
+                num_heads=num_heads,
+                mixer='Global',
+                HW=None,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=nn.Swish,
+                attn_drop=attn_drop_rate,
+                drop_path=drop_path,
+                norm_layer='nn.LayerNorm',
+                epsilon=1e-05,
+                prenorm=False) for i in range(depth)
+        ])
+        self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
+        self.conv3 = ConvBNLayer(
+            hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
+        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
+        self.conv4 = ConvBNLayer(
+            2 * in_channels, in_channels // 8, padding=1, act=nn.Swish)
+        self.conv1x1 = ConvBNLayer(
+            in_channels // 8, dims, kernel_size=1, act=nn.Swish)
+        self.out_channels = dims
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            zeros_(m.bias)
+            ones_(m.weight)
+
+    def forward(self, x):
+        # for use guide
+        if self.use_guide:
+            z = x.clone()
+            z.stop_gradient = True
+        else:
+            z = x
+        # for short cut
+        h = z
+        # reduce dim
+        z = self.conv1(z)
+        z = self.conv2(z)
+        # SVTR global block
+        B, C, H, W = z.shape
+        z = z.flatten(2).transpose([0, 2, 1])
+        for blk in self.svtr_block:
+            z = blk(z)
+        z = self.norm(z)
+        # last stage
+        z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
+        z = self.conv3(z)
+        z = paddle.concat((h, z), axis=1)
+        z = self.conv1x1(self.conv4(z))
+        return z
+
+
 class SequenceEncoder(nn.Layer):
     def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
         super(SequenceEncoder, self).__init__()
         self.encoder_reshape = Im2Seq(in_channels)
         self.out_channels = self.encoder_reshape.out_channels
+        self.encoder_type = encoder_type
         if encoder_type == 'reshape':
             self.only_reshape = True
         else:
             support_encoder_dict = {
                 'reshape': Im2Seq,
                 'fc': EncoderWithFC,
-                'rnn': EncoderWithRNN
+                'rnn': EncoderWithRNN,
+                'svtr': EncoderWithSVTR
             }
             assert encoder_type in support_encoder_dict, '{} must in {}'.format(
                 encoder_type, support_encoder_dict.keys())
-            self.encoder = support_encoder_dict[encoder_type](
-                self.encoder_reshape.out_channels, hidden_size)
+            if encoder_type == "svtr":
+                self.encoder = support_encoder_dict[encoder_type](
+                    self.encoder_reshape.out_channels, **kwargs)
+            else:
+                self.encoder = support_encoder_dict[encoder_type](
+                    self.encoder_reshape.out_channels, hidden_size)
             self.out_channels = self.encoder.out_channels
             self.only_reshape = False

     def forward(self, x):
-        x = self.encoder_reshape(x)
-        if not self.only_reshape:
-            x = self.encoder(x)
-        return x
+        if self.encoder_type != 'svtr':
+            x = self.encoder_reshape(x)
+            if not self.only_reshape:
+                x = self.encoder(x)
+            return x
+        else:
+            x = self.encoder(x)
+            x = self.encoder_reshape(x)
+            return x
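For the 'svtr' encoder type, SequenceEncoder deliberately skips the up-front Im2Seq: EncoderWithSVTR mixes features while they are still a 4-D map, and only afterwards does Im2Seq flatten them to a sequence. A small shape-check sketch (sizes are illustrative, and it assumes the module is importable):

import paddle
from ppocr.modeling.necks.rnn import SequenceEncoder

encoder = SequenceEncoder(in_channels=512, encoder_type='svtr',
                          dims=64, depth=2, hidden_dims=120, use_guide=True)
feat = paddle.randn([2, 512, 1, 40])  # NCHW feature map, height already 1
seq = encoder(feat)                   # SVTR blocks run in 4-D, then Im2Seq
print(seq.shape)                      # expected: [2, 40, 64]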
@@ -128,6 +128,8 @@ class STN_ON(nn.Layer):
         self.out_channels = in_channels

     def forward(self, image):
+        if len(image.shape) == 5:
+            image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
         stn_input = paddle.nn.functional.interpolate(
             image, self.tps_inputsize, mode="bilinear", align_corners=True)
         stn_img_feat, ctrl_points = self.stn_head(stn_input)
...
@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer):
         assert source_control_points.shape[2] == 2
         batch_size = paddle.shape(source_control_points)[0]
-        self.padding_matrix = paddle.expand(
+        padding_matrix = paddle.expand(
             self.padding_matrix, shape=[batch_size, 3, 2])
-        Y = paddle.concat([source_control_points, self.padding_matrix], 1)
+        Y = paddle.concat([source_control_points, padding_matrix], 1)
         mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
         source_coordinate = paddle.matmul(self.target_coordinate_repr,
                                           mapping_matrix)
...
@@ -30,7 +30,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr


-def build_optimizer(config, epochs, step_each_epoch, parameters):
+def build_optimizer(config, epochs, step_each_epoch, model):
     from . import regularizer, optimizer
     config = copy.deepcopy(config)
     # step1 build lr
@@ -43,6 +43,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         if not hasattr(regularizer, reg_name):
             reg_name += 'Decay'
         reg = getattr(regularizer, reg_name)(**reg_config)()
+    elif 'weight_decay' in config:
+        reg = config.pop('weight_decay')
     else:
         reg = None
@@ -57,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         weight_decay=reg,
         grad_clip=grad_clip,
         **config)
-    return optim(parameters), lr
+    return optim(model), lr
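Since build_optimizer now receives the model itself rather than a parameter list, callers switch from passing model.parameters() to passing model; the new elif also lets a config carry a plain `weight_decay` float instead of a regularizer block. A minimal sketch with a stand-in model and an illustrative config:

import paddle
from ppocr.optimizer import build_optimizer

model = paddle.nn.Linear(10, 10)  # stand-in for a real OCR model
optim_config = {
    'name': 'Adam',
    'beta1': 0.9,
    'beta2': 0.999,
    'weight_decay': 0.05,  # picked up by the new elif branch above
    'lr': {'name': 'Cosine', 'learning_rate': 0.001},
}
optimizer, lr_scheduler = build_optimizer(
    optim_config, epochs=500, step_each_epoch=100, model=model)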
@@ -42,13 +42,13 @@ class Momentum(object):
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip

-    def __call__(self, parameters):
+    def __call__(self, model):
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=parameters)
+            parameters=model.parameters())
         return opt
@@ -75,7 +75,7 @@ class Adam(object):
         self.name = name
         self.lazy_mode = lazy_mode

-    def __call__(self, parameters):
+    def __call__(self, model):
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -85,7 +85,7 @@ class Adam(object):
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
-            parameters=parameters)
+            parameters=model.parameters())
         return opt
@@ -117,7 +117,7 @@ class RMSProp(object):
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip

-    def __call__(self, parameters):
+    def __call__(self, model):
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -125,7 +125,7 @@ class RMSProp(object):
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=parameters)
+            parameters=model.parameters())
         return opt
@@ -148,7 +148,7 @@ class Adadelta(object):
         self.grad_clip = grad_clip
         self.name = name

-    def __call__(self, parameters):
+    def __call__(self, model):
         opt = optim.Adadelta(
             learning_rate=self.learning_rate,
             epsilon=self.epsilon,
@@ -156,7 +156,7 @@ class Adadelta(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             name=self.name,
-            parameters=parameters)
+            parameters=model.parameters())
         return opt
@@ -165,31 +165,55 @@ class AdamW(object):
                  learning_rate=0.001,
                  beta1=0.9,
                  beta2=0.999,
-                 epsilon=1e-08,
+                 epsilon=1e-8,
                  weight_decay=0.01,
+                 multi_precision=False,
                  grad_clip=None,
+                 no_weight_decay_name=None,
+                 one_dim_param_no_weight_decay=False,
                  name=None,
                  lazy_mode=False,
-                 **kwargs):
+                 **args):
+        super().__init__()
         self.learning_rate = learning_rate
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
-        self.learning_rate = learning_rate
+        self.grad_clip = grad_clip
         self.weight_decay = 0.01 if weight_decay is None else weight_decay
         self.grad_clip = grad_clip
         self.name = name
         self.lazy_mode = lazy_mode
+        self.multi_precision = multi_precision
+        self.no_weight_decay_name_list = no_weight_decay_name.split(
+        ) if no_weight_decay_name else []
+        self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay

-    def __call__(self, parameters):
+    def __call__(self, model):
+        parameters = model.parameters()
+        self.no_weight_decay_param_name_list = [
+            p.name for n, p in model.named_parameters()
+            if any(nd in n for nd in self.no_weight_decay_name_list)
+        ]
+        if self.one_dim_param_no_weight_decay:
+            self.no_weight_decay_param_name_list += [
+                p.name for n, p in model.named_parameters()
+                if len(p.shape) == 1
+            ]
         opt = optim.AdamW(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
             beta2=self.beta2,
             epsilon=self.epsilon,
+            parameters=parameters,
             weight_decay=self.weight_decay,
+            multi_precision=self.multi_precision,
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
-            parameters=parameters)
+            apply_decay_param_fun=self._apply_decay_param_fun)
         return opt
+
+    def _apply_decay_param_fun(self, name):
+        return name not in self.no_weight_decay_param_name_list
\ No newline at end of file
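The rewritten AdamW supports the common transformer recipe: exclude parameters whose names contain given substrings, and optionally all 1-D parameters (biases, norm scales), from weight decay via paddle's apply_decay_param_fun. A minimal sketch with a stand-in model (the name strings and values are illustrative):

import paddle

builder = AdamW(learning_rate=0.001,
                weight_decay=0.05,
                no_weight_decay_name='norm pos_embed',
                one_dim_param_no_weight_decay=True)
model = paddle.nn.Sequential(paddle.nn.Linear(8, 8), paddle.nn.LayerNorm(8))
opt = builder(model)  # biases and LayerNorm params are excluded from decay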
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
 from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
-    SEEDLabelDecode, PRENLabelDecode
+    SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None):
         'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
         'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
-        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode'
+        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
+        'DistillationSARLabelDecode', 'SVTRLabelDecode'
     ]
     if config['name'] == 'PSEPostProcess':
...
@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
                  use_space_char=False,
                  model_name=["student"],
                  key=None,
+                 multi_head=False,
                  **kwargs):
         super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
                                                          use_space_char)
@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
         self.model_name = model_name

         self.key = key
+        self.multi_head = multi_head

     def __call__(self, preds, label=None, *args, **kwargs):
         output = dict()
@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
             pred = preds[name]
             if self.key is not None:
                 pred = pred[self.key]
+            if self.multi_head and isinstance(pred, dict):
+                pred = pred['ctc']
             output[name] = super().__call__(pred, label=label, *args, **kwargs)
         return output
@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode):
         return [self.padding_idx]


+class DistillationSARLabelDecode(SARLabelDecode):
+    """
+    Convert between text-label and text-index
+    """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 model_name=["student"],
+                 key=None,
+                 multi_head=False,
+                 **kwargs):
+        super(DistillationSARLabelDecode, self).__init__(character_dict_path,
+                                                         use_space_char)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+
+        self.key = key
+        self.multi_head = multi_head
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            if self.multi_head and isinstance(pred, dict):
+                pred = pred['sar']
+            output[name] = super().__call__(pred, label=label, *args, **kwargs)
+        return output
+
+
 class PRENLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
@@ -714,3 +752,40 @@ class PRENLabelDecode(BaseRecLabelDecode):
             return text
         label = self.decode(label)
         return text, label


+class SVTRLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(SVTRLabelDecode, self).__init__(character_dict_path,
+                                              use_space_char)
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        if isinstance(preds, tuple):
+            preds = preds[-1]
+        if isinstance(preds, paddle.Tensor):
+            preds = preds.numpy()
+        preds_idx = preds.argmax(axis=-1)
+        preds_prob = preds.max(axis=-1)
+
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        return_text = []
+        for i in range(0, len(text), 3):
+            text0 = text[i]
+            text1 = text[i + 1]
+            text2 = text[i + 2]
+
+            text_pred = [text0[0], text1[0], text2[0]]
+            text_prob = [text0[1], text1[1], text2[1]]
+            id_max = text_prob.index(max(text_prob))
+            return_text.append((text_pred[id_max], text_prob[id_max]))
+        if label is None:
+            return return_text
+        label = self.decode(label)
+        return return_text, label
+
+    def add_special_char(self, dict_character):
+        dict_character = ['blank'] + dict_character
+        return dict_character
\ No newline at end of file
@@ -21,7 +21,7 @@ l
 8
 .
 j
 p
...
@@ -22,7 +22,7 @@ l
 8
 .
 j
 p
...
===========================train_params===========================
model_name:ch_PPOCRv2_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:amp
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
quant_export:null
fpgm_export:
distill_export:null
export1:null
export2:null
inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_det_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_params===========================
model_name:ch_PPOCRv2_det_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:amp
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:pact_train
norm_train:null
pact_train:deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_export:
distill_export:null
export1:null
export2:null
inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_det_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]