Commit a5530565 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents a9d5349c 37eec4d5
......@@ -234,6 +234,9 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
| rec_resnet_stn_bilstm_att.yml | SEED | Aster_Resnet | STN | BiLSTM | att |
*其中SEED模型需要额外加载FastText训练好的[语言模型](https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz)
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
......@@ -460,5 +463,3 @@ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_trai
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
```
......@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
label = [0] * len(self.character)
for x in text:
label[x] += 1
data['label_ace'] = np.array(label)
return data
def add_special_char(self, dict_character):
......@@ -342,6 +347,38 @@ class AttnLabelEncode(BaseRecLabelEncode):
return idx
class SEEDLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SEEDLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1 # conclude eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
)
data['label'] = np.array(text)
return data
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
......@@ -421,7 +458,6 @@ class TableLabelEncode(object):
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
......
......@@ -23,6 +23,7 @@ import sys
import six
import cv2
import numpy as np
import fasttext
class DecodeImage(object):
......@@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
......@@ -133,6 +135,17 @@ class ToCHWImage(object):
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
......
......@@ -88,17 +88,19 @@ class RecResizeImg(object):
image_shape,
infer_mode=False,
character_type='ch',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_type = character_type
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_type == "ch":
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape)
norm_img = resize_norm_img(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
......@@ -174,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape):
def resize_norm_img(img, image_shape, padding=True):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
......
......@@ -28,6 +28,8 @@ from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -48,9 +50,8 @@ def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
'TableAttentionLoss', 'SARLoss'
'TableAttentionLoss', 'SARLoss', 'AsterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
class ACELoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_func = nn.CrossEntropyLoss(
weight=None,
ignore_index=0,
reduction='none',
soft_label=True,
axis=-1)
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
predicts = nn.functional.softmax(predicts, axis=-1)
aggregation_preds = paddle.sum(predicts, axis=1)
aggregation_preds = paddle.divide(aggregation_preds, div)
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch)
return {"loss_ace": loss}
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def __init__(self,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None):
super().__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center
if init_center:
assert os.path.exists(
center_file_path
), f"center path({center_file_path}) must exist when init_center is set as True."
with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f)
for key in char_dict.keys():
self.centers[key] = paddle.to_tensor(char_dict[key])
def __call__(self, predicts, batch):
assert isinstance(predicts, (list, tuple))
features, predicts = predicts
feats_reshape = paddle.reshape(
features, [-1, features.shape[-1]]).astype("float64")
label = paddle.argmax(predicts, axis=2)
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
batch_size = feats_reshape.shape[0]
#calc feat * feat
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
#dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0])
#first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats_reshape,
paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
label = paddle.expand(
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64") #get mask
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}
......@@ -15,6 +15,10 @@
import paddle
import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
......
......@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class CosineEmbeddingLoss(nn.Layer):
def __init__(self, margin=0.):
super(CosineEmbeddingLoss, self).__init__()
self.margin = margin
self.epsilon = 1e-12
def forward(self, x1, x2, target):
similarity = paddle.fluid.layers.reduce_sum(
x1 * x2, dim=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1) + self.epsilon)
one_list = paddle.full_like(target, fill_value=1)
out = paddle.fluid.layers.reduce_mean(
paddle.where(
paddle.equal(target, one_list), 1. - similarity,
paddle.maximum(
paddle.zeros_like(similarity), similarity - self.margin)))
return out
class AsterLoss(nn.Layer):
def __init__(self,
weight=None,
size_average=True,
ignore_index=-100,
sequence_normalize=False,
sample_normalize=True,
**kwargs):
super(AsterLoss, self).__init__()
self.weight = weight
self.size_average = size_average
self.ignore_index = ignore_index
self.sequence_normalize = sequence_normalize
self.sample_normalize = sample_normalize
self.loss_sem = CosineEmbeddingLoss()
self.is_cosin_loss = True
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
sem_target = batch[3].astype('float32')
embedding_vectors = predicts['embedding_vectors']
rec_pred = predicts['rec_pred']
if not self.is_cosin_loss:
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
else:
label_target = paddle.ones([embedding_vectors.shape[0]])
sem_loss = paddle.sum(
self.loss_sem(embedding_vectors, sem_target, label_target))
# rec loss
batch_size, def_max_length = targets.shape[0], targets.shape[1]
mask = paddle.zeros([batch_size, def_max_length])
for i in range(batch_size):
mask[i, :label_lengths[i]] = 1
mask = paddle.cast(mask, "float32")
max_length = max(label_lengths)
assert max_length == rec_pred.shape[1]
targets = targets[:, :max_length]
mask = mask[:, :max_length]
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
input = nn.functional.log_softmax(rec_pred, axis=1)
targets = paddle.reshape(targets, [-1, 1])
mask = paddle.reshape(mask, [-1, 1])
output = -paddle.index_sample(input, index=targets) * mask
output = paddle.sum(output)
if self.sequence_normalize:
output = output / paddle.sum(mask)
if self.sample_normalize:
output = output / batch_size
loss = output + sem_loss * 0.1
return {'loss': loss}
......@@ -21,16 +21,24 @@ from paddle import nn
class CTCLoss(nn.Layer):
def __init__(self, **kwargs):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
self.use_focal_loss = use_focal_loss
def forward(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
loss = loss.mean() # sum
if self.use_focal_loss:
weight = paddle.exp(-loss)
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
loss = loss.mean()
return {'loss': loss}
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .ace_loss import ACELoss
from .center_loss import CenterLoss
from .rec_ctc_loss import CTCLoss
class EnhancedCTCLoss(nn.Layer):
def __init__(self,
use_focal_loss=False,
use_ace_loss=False,
ace_loss_weight=0.1,
use_center_loss=False,
center_loss_weight=0.05,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None,
**kwargs):
super(EnhancedCTCLoss, self).__init__()
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
self.use_ace_loss = False
if use_ace_loss:
self.use_ace_loss = use_ace_loss
self.ace_loss_func = ACELoss()
self.ace_loss_weight = ace_loss_weight
self.use_center_loss = False
if use_center_loss:
self.use_center_loss = use_center_loss
self.center_loss_func = CenterLoss(
num_classes=num_classes,
feat_dim=feat_dim,
init_center=init_center,
center_file_path=center_file_path)
self.center_loss_weight = center_loss_weight
def __call__(self, predicts, batch):
loss = self.ctc_loss_func(predicts, batch)["loss"]
if self.use_center_loss:
center_loss = self.center_loss_func(
predicts, batch)["loss_center"] * self.center_loss_weight
loss = loss + center_loss
if self.use_ace_loss:
ace_loss = self.ace_loss_func(
predicts, batch)["loss_ace"] * self.ace_loss_weight
loss = loss + ace_loss
return {'enhanced_ctc_loss': loss}
......@@ -9,11 +9,14 @@ from paddle import nn
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
reduction="mean", ignore_index=92)
def forward(self, predicts, batch):
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation
predict = predicts[:, :
-1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype(
"int64")[:, 1:] # ignore first index of target in loss calculation
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
1], predict.shape[2]
assert len(label.shape) == len(list(predict.shape)) - 1, \
......
......@@ -13,13 +13,20 @@
# limitations under the License.
import Levenshtein
import string
class RecMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator
self.is_filter = is_filter
self.reset()
def _normalize_text(self, text):
text = ''.join(
filter(lambda x: x in (string.digits + string.ascii_letters), text))
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
......@@ -28,6 +35,9 @@ class RecMetric(object):
for (pred, pred_conf), (target, _) in zip(preds, labels):
pred = pred.replace(" ", "")
target = target.replace(" ", "")
if self.is_filter:
pred = self._normalize_text(pred)
target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred == target:
......@@ -57,4 +67,3 @@ class RecMetric(object):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
......@@ -28,8 +28,10 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import sys
import math
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2D(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2D(
in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
# [n_position]
positions = paddle.arange(0, n_position)
# [feat_dim]
dim_range = paddle.arange(0, feat_dim)
dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
# [n_position, feat_dim]
angles = paddle.unsqueeze(
positions, axis=1) / paddle.unsqueeze(
dim_range, axis=0)
angles = paddle.cast(angles, "float32")
angles[:, 0::2] = paddle.sin(angles[:, 0::2])
angles[:, 1::2] = paddle.cos(angles[:, 1::2])
return angles
class AsterBlock(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(AsterBlock, self).__init__()
self.conv1 = conv1x1(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet_ASTER(nn.Layer):
"""For aster or crnn"""
def __init__(self, with_lstm=True, n_group=1, in_channels=3):
super(ResNet_ASTER, self).__init__()
self.with_lstm = with_lstm
self.n_group = n_group
self.layer0 = nn.Sequential(
nn.Conv2D(
in_channels,
32,
kernel_size=(3, 3),
stride=1,
padding=1,
bias_attr=False),
nn.BatchNorm2D(32),
nn.ReLU())
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
if with_lstm:
self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
self.out_channels = 2 * 256
else:
self.out_channels = 512
def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
layers = []
layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(AsterBlock(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x0 = self.layer0(x)
x1 = self.layer1(x0)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
cnn_feat = x5.squeeze(2) # [N, c, w]
cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
if self.with_lstm:
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
return cnn_feat
......@@ -29,13 +29,14 @@ def build_head(config):
from .rec_srn_head import SRNHead
from .rec_nrtr_head import Transformer
from .rec_sar_head import SARHead
from .rec_aster_head import AsterHead
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead'
'TableAttentionHead', 'SARHead', 'AsterHead'
]
#table head
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment