Unverified Commit f6532a0e authored by andyjpaddle's avatar andyjpaddle Committed by GitHub
Browse files

add ppocrv3 rec (#6033)

* add ppocrv3 rec
parent 6902d160
Global:
debug: false
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v3
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 5
regularizer:
name: L2
factor: 3.0e-05
Architecture:
model_type: rec
algorithm: SVTR
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2]
last_pool_type: avg
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 64
depth: 2
hidden_dims: 120
use_guide: True
Head:
fc_decay: 0.00001
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- SARLoss:
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
ignore_space: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
- RecAug:
- MultiLabelEncode:
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_sar
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_sar
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v3_distillation
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.0005, 0.00005]
warmup_epoch: 5
regularizer:
name: L2
factor: 3.0e-05
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: SVTR
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2]
last_pool_type: avg
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 64
depth: 2
hidden_dims: 120
use_guide: True
Head:
fc_decay: 0.00001
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: SVTR
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2]
last_pool_type: avg
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 64
depth: 2
hidden_dims: 120
use_guide: True
Head:
fc_decay: 0.00001
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
multi_head: True
dis_head: ctc
name: dml_ctc
- DistillationDMLLoss:
weight: 0.5
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
multi_head: True
dis_head: sar
name: dml_sar
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
multi_head: True
- DistillationSARLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
multi_head: True
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
multi_head: True
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
ignore_space: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
- RecAug:
- MultiLabelEncode:
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_sar
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_sar
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
......@@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -22,6 +22,7 @@ import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from ppocr.utils.logging import get_logger
......@@ -1007,3 +1008,34 @@ class VQATokenLabelEncode(object):
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
return gt_label
class MultiLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(MultiLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_out = dict()
data_out['img_path'] = data.get('img_path', None)
data_out['image'] = data['image']
ctc = self.ctc_encode.__call__(data_ctc)
sar = self.sar_encode.__call__(data_sar)
if ctc is None or sar is None:
return None
data_out['label_ctc'] = ctc['label']
data_out['label_sar'] = sar['label']
data_out['length'] = ctc['length']
return data_out
......@@ -32,6 +32,49 @@ class RecAug(object):
return data
class RecConAug(object):
def __init__(self,
prob=0.5,
image_shape=(32, 320, 3),
max_text_length=25,
ext_data_num=1,
**kwargs):
self.ext_data_num = ext_data_num
self.prob = prob
self.max_text_length = max_text_length
self.image_shape = image_shape
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
def merge_ext_data(self, data, ext_data):
ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
self.image_shape[0])
ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
self.image_shape[0])
data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
ext_data['image'] = cv2.resize(ext_data['image'],
(ext_w, self.image_shape[0]))
data['image'] = np.concatenate(
[data['image'], ext_data['image']], axis=1)
data["label"] += ext_data["label"]
return data
def __call__(self, data):
rnd_num = random.random()
if rnd_num > self.prob:
return data
for idx, ext_data in enumerate(data["ext_data"]):
if len(data["label"]) + len(ext_data[
"label"]) > self.max_text_length:
break
concat_ratio = data['image'].shape[1] / data['image'].shape[
0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
if concat_ratio > self.max_wh_ratio:
break
data = self.merge_ext_data(data, ext_data)
data.pop("ext_data")
return data
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
......@@ -98,10 +141,13 @@ class RecResizeImg(object):
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
norm_img = resize_norm_img_chinese(img, self.image_shape)
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape, self.padding)
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
......@@ -220,7 +266,8 @@ def resize_norm_img(img, image_shape, padding=True):
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_chinese(img, image_shape):
......@@ -230,7 +277,7 @@ def resize_norm_img_chinese(img, image_shape):
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
imgW = int(imgH * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
......@@ -246,7 +293,8 @@ def resize_norm_img_chinese(img, image_shape):
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_srn(img, image_shape):
......
......@@ -49,7 +49,8 @@ class SimpleDataSet(Dataset):
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
......@@ -87,7 +88,7 @@ class SimpleDataSet(Dataset):
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:2]
load_data_ops = self.ops[:self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
......@@ -108,7 +109,10 @@ class SimpleDataSet(Dataset):
data['image'] = img
data = transform(data, load_data_ops)
if data is None or data['polys'].shape[1] != 4:
if data is None:
continue
if 'polys' in data.keys():
if data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
......
......@@ -34,6 +34,7 @@ from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -60,7 +61,7 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss'
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -106,8 +106,8 @@ class DMLLoss(nn.Layer):
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
out1 = self.act(out1) + 1e-10
out2 = self.act(out2) + 1e-10
if self.use_log:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
......
......@@ -18,8 +18,10 @@ import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss
from .rec_sar_loss import SARLoss
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
......
......@@ -18,6 +18,7 @@ import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .det_db_loss import DBLoss
......@@ -46,11 +47,15 @@ class DistillationDMLLoss(DMLLoss):
act=None,
use_log=False,
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="dml"):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
......@@ -97,6 +102,10 @@ class DistillationDMLLoss(DMLLoss):
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
loss = super().forward(out1[self.dis_head],
out2[self.dis_head])
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
......@@ -123,11 +132,50 @@ class DistillationDMLLoss(DMLLoss):
class DistillationCTCLoss(CTCLoss):
def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'ctc' in out, 'multi head has multi out'
loss = super().forward(out['ctc'], batch[:2] + batch[3:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationSARLoss(SARLoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
name="loss_sar",
**kwargs):
ignore_index = kwargs.get('ignore_index', 92)
super().__init__(ignore_index=ignore_index)
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
......@@ -135,6 +183,10 @@ class DistillationCTCLoss(CTCLoss):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'sar' in out, 'multi head has multi out'
loss = super().forward(out['sar'], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
class MultiLoss(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
self.loss_funcs = {}
self.loss_list = kwargs.pop('loss_config_list')
self.weight_1 = kwargs.get('weight_1', 1.0)
self.weight_2 = kwargs.get('weight_2', 1.0)
self.gtc_loss = kwargs.get('gtc_loss', 'sar')
for loss_info in self.loss_list:
for name, param in loss_info.items():
if param is not None:
kwargs.update(param)
loss = eval(name)(**kwargs)
self.loss_funcs[name] = loss
def forward(self, predicts, batch):
self.total_loss = {}
total_loss = 0.0
# batch [image, label_ctc, label_sar, length, valid_ratio]
for name, loss_func in self.loss_funcs.items():
if name == 'CTCLoss':
loss = loss_func(predicts['ctc'],
batch[:2] + batch[3:])['loss'] * self.weight_1
elif name == 'SARLoss':
loss = loss_func(predicts['sar'],
batch[:1] + batch[2:])['loss'] * self.weight_2
else:
raise NotImplementedError(
'{} is not supported in MultiLoss yet'.format(name))
self.total_loss[name] = loss
total_loss += loss
self.total_loss['loss'] = total_loss
return self.total_loss
......@@ -9,8 +9,9 @@ from paddle import nn
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
ignore_index = kwargs.get('ignore_index', 92) # 6626
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
reduction="mean", ignore_index=92)
reduction="mean", ignore_index=ignore_index)
def forward(self, predicts, batch):
predict = predicts[:, :
......
......@@ -17,9 +17,14 @@ import string
class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
def __init__(self,
main_indicator='acc',
is_filter=False,
ignore_space=True,
**kwargs):
self.main_indicator = main_indicator
self.is_filter = is_filter
self.ignore_space = ignore_space
self.eps = 1e-5
self.reset()
......@@ -34,6 +39,7 @@ class RecMetric(object):
all_num = 0
norm_edit_dis = 0.0
for (pred, pred_conf), (target, _) in zip(preds, labels):
if self.ignore_space:
pred = pred.replace(" ", "")
target = target.replace(" ", "")
if self.is_filter:
......
......@@ -83,7 +83,11 @@ class BaseModel(nn.Layer):
y["neck_out"] = x
if self.use_head:
x = self.head(x, targets=data)
if isinstance(x, dict):
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and 'ctc_neck' in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
......
......@@ -53,8 +53,8 @@ class DistillationModel(nn.Layer):
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x):
def forward(self, x, data=None):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x)
result_dict[model_name] = self.model_list[idx](x, data)
return result_dict
......@@ -31,9 +31,11 @@ def build_backbone(config, model_type):
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN'
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
'SVTRNet'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
......@@ -103,7 +103,12 @@ class DepthwiseSeparable(nn.Layer):
class MobileNetV1Enhance(nn.Layer):
def __init__(self, in_channels=3, scale=0.5, **kwargs):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
......@@ -200,7 +205,7 @@ class MobileNetV1Enhance(nn.Layer):
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
......@@ -208,7 +213,9 @@ class MobileNetV1Enhance(nn.Layer):
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
......
This diff is collapsed.
......@@ -32,6 +32,7 @@ def build_head(config):
from .rec_sar_head import SARHead
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
# cls head
from .cls_head import ClsHead
......@@ -44,7 +45,8 @@ def build_head(config):
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead'
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead'
]
#table head
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead
class MultiHead(nn.Layer):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop('head_list')
self.gtc_head = 'sar'
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
if name == 'SARHead':
# sar head
sar_args = self.head_list[idx][name]
self.sar_head = eval(name)(in_channels=in_channels, \
out_channels=out_channels_list['SARLabelDecode'], **sar_args)
elif name == 'CTCHead':
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]['Neck']
encoder_type = neck_args.pop('name')
self.encoder = encoder_type
self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
encoder_type=encoder_type, **neck_args)
# ctc head
head_args = self.head_list[idx][name]['Head']
self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
out_channels=out_channels_list['CTCLabelDecode'], **head_args)
else:
raise NotImplementedError(
'{} is not supported in MultiHead yet'.format(name))
def forward(self, x, targets=None):
ctc_encoder = self.ctc_encoder(x)
ctc_out = self.ctc_head(ctc_encoder, targets)
head_out = dict()
head_out['ctc'] = ctc_out
head_out['ctc_neck'] = ctc_encoder
# eval mode
if not self.training:
return ctc_out
if self.gtc_head == 'sar':
sar_out = self.sar_head(x, targets[1:])
head_out['sar'] = sar_out
return head_out
else:
return head_out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment