Commit cbba27b4 authored by myhloli's avatar myhloli
Browse files

refactor: reorganize project structure and update import paths

parent 3027c677
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import torch
import torch.nn.functional as F
from torch import nn
from ..common import Activation
NET_CONFIG_det = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5": [
[3, 128, 256, 2, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
],
"blocks6": [
[5, 256, 512, 2, True],
[5, 512, 512, 1, True],
[5, 512, 512, 1, False],
[5, 512, 512, 1, False],
],
}
NET_CONFIG_rec = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
"blocks5": [
[3, 128, 256, (1, 2), False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
],
"blocks6": [
[5, 256, 512, (2, 1), True],
[5, 512, 512, 1, True],
[5, 512, 512, (2, 1), False],
[5, 512, 512, 1, False],
],
}
def make_divisible(v, divisor=16, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class LearnableAffineBlock(nn.Module):
def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
super().__init__()
self.scale = nn.Parameter(torch.Tensor([scale_value]))
self.bias = nn.Parameter(torch.Tensor([bias_value]))
def forward(self, x):
return self.scale * x + self.bias
class ConvBNLayer(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(
out_channels,
)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Act(nn.Module):
def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
super().__init__()
if act == "hswish":
self.act = nn.Hardswish(inplace=True)
else:
assert act == "relu"
self.act = Activation(act)
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
return self.lab(self.act(x))
class LearnableRepLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
num_conv_branches=1,
lr_mult=1.0,
lab_lr=0.1,
):
super().__init__()
self.is_repped = False
self.groups = groups
self.stride = stride
self.kernel_size = kernel_size
self.in_channels = in_channels
self.out_channels = out_channels
self.num_conv_branches = num_conv_branches
self.padding = (kernel_size - 1) // 2
self.identity = (
nn.BatchNorm2d(
num_features=in_channels,
)
if out_channels == in_channels and stride == 1
else None
)
self.conv_kxk = nn.ModuleList(
[
ConvBNLayer(
in_channels,
out_channels,
kernel_size,
stride,
groups=groups,
lr_mult=lr_mult,
)
for _ in range(self.num_conv_branches)
]
)
self.conv_1x1 = (
ConvBNLayer(
in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
)
if kernel_size > 1
else None
)
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
# for export
if self.is_repped:
out = self.lab(self.reparam_conv(x))
if self.stride != 2:
out = self.act(out)
return out
out = 0
if self.identity is not None:
out += self.identity(x)
if self.conv_1x1 is not None:
out += self.conv_1x1(x)
for conv in self.conv_kxk:
out += conv(x)
out = self.lab(out)
if self.stride != 2:
out = self.act(out)
return out
def rep(self):
if self.is_repped:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
self.is_repped = True
def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
if not isinstance(kernel1x1, torch.Tensor):
return 0
else:
return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])
def _get_kernel_bias(self):
kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(
kernel_conv_1x1, self.kernel_size // 2
)
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
kernel_conv_kxk = 0
bias_conv_kxk = 0
for conv in self.conv_kxk:
kernel, bias = self._fuse_bn_tensor(conv)
kernel_conv_kxk += kernel
bias_conv_kxk += bias
kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
return kernel_reparam, bias_reparam
def _fuse_bn_tensor(self, branch):
if not branch:
return 0, 0
elif isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_value = torch.zeros(
(self.in_channels, input_dim, self.kernel_size, self.kernel_size),
dtype=branch.weight.dtype,
)
for i in range(self.in_channels):
kernel_value[
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch._mean
running_var = branch._variance
gamma = branch.weight
beta = branch.bias
eps = branch._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
class SELayer(nn.Module):
def __init__(self, channel, reduction=4, lr_mult=1.0):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
)
self.hardsigmoid = nn.Hardsigmoid(inplace=True)
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = identity * x
return x
class LCNetV3Block(nn.Module):
def __init__(
self,
in_channels,
out_channels,
stride,
dw_size,
use_se=False,
conv_kxk_num=4,
lr_mult=1.0,
lab_lr=0.1,
):
super().__init__()
self.use_se = use_se
self.dw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=dw_size,
stride=stride,
groups=in_channels,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr,
)
if use_se:
self.se = SELayer(in_channels, lr_mult=lr_mult)
self.pw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr,
)
def forward(self, x):
x = self.dw_conv(x)
if self.use_se:
x = self.se(x)
x = self.pw_conv(x)
return x
class PPLCNetV3(nn.Module):
def __init__(
self,
scale=1.0,
conv_kxk_num=4,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
lab_lr=0.1,
det=False,
**kwargs
):
super().__init__()
self.scale = scale
self.lr_mult_list = lr_mult_list
self.det = det
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
assert isinstance(
self.lr_mult_list, (list, tuple)
), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list)
)
assert (
len(self.lr_mult_list) == 6
), "lr_mult_list length should be 6 but got {}".format(len(self.lr_mult_list))
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=make_divisible(16 * scale),
kernel_size=3,
stride=2,
lr_mult=self.lr_mult_list[0],
)
self.blocks2 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[1],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks2"])
]
)
self.blocks3 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[2],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks3"])
]
)
self.blocks4 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[3],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks4"])
]
)
self.blocks5 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[4],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks5"])
]
)
self.blocks6 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[5],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks6"])
]
)
self.out_channels = make_divisible(512 * scale)
if self.det:
mv_c = [16, 24, 56, 480]
self.out_channels = [
make_divisible(self.net_config["blocks3"][-1][2] * scale),
make_divisible(self.net_config["blocks4"][-1][2] * scale),
make_divisible(self.net_config["blocks5"][-1][2] * scale),
make_divisible(self.net_config["blocks6"][-1][2] * scale),
]
self.layer_list = nn.ModuleList(
[
nn.Conv2d(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
]
)
self.out_channels = [
int(mv_c[0] * scale),
int(mv_c[1] * scale),
int(mv_c[2] * scale),
int(mv_c[3] * scale),
]
def forward(self, x):
out_list = []
x = self.conv1(x)
x = self.blocks2(x)
x = self.blocks3(x)
out_list.append(x)
x = self.blocks4(x)
out_list.append(x)
x = self.blocks5(x)
out_list.append(x)
x = self.blocks6(x)
out_list.append(x)
if self.det:
out_list[0] = self.layer_list[0](out_list[0])
out_list[1] = self.layer_list[1](out_list[1])
out_list[2] = self.layer_list[2](out_list[2])
out_list[3] = self.layer_list[3](out_list[3])
return out_list
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
from torch import nn
from .det_mobilenet_v3 import ConvBNLayer, ResidualUnit, make_divisible
class MobileNetV3(nn.Module):
def __init__(
self,
in_channels=3,
model_name="small",
scale=0.5,
large_stride=None,
small_stride=None,
**kwargs
):
super(MobileNetV3, self).__init__()
if small_stride is None:
small_stride = [2, 2, 2, 2]
if large_stride is None:
large_stride = [1, 2, 2, 2]
assert isinstance(
large_stride, list
), "large_stride type must " "be list but got {}".format(type(large_stride))
assert isinstance(
small_stride, list
), "small_stride type must " "be list but got {}".format(type(small_stride))
assert (
len(large_stride) == 4
), "large_stride length must be " "4 but got {}".format(len(large_stride))
assert (
len(small_stride) == 4
), "small_stride length must be " "4 but got {}".format(len(small_stride))
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", large_stride[0]],
[3, 64, 24, False, "relu", (large_stride[1], 1)],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", (large_stride[2], 1)],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hard_swish", 1],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1],
[5, 672, 160, True, "hard_swish", (large_stride[3], 1)],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", (small_stride[0], 1)],
[3, 72, 24, False, "relu", (small_stride[1], 1)],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hard_swish", (small_stride[2], 1)],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1],
[5, 288, 96, True, "hard_swish", (small_stride[3], 1)],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError(
"mode[" + model_name + "_model] is not implemented!"
)
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert (
scale in supported_scale
), "supported scales are {} but input scale is {}".format(
supported_scale, scale
)
inplanes = 16
# conv1
self.conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act="hard_swish",
name="conv1",
)
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for k, exp, c, se, nl, s in cfg:
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2),
)
)
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act="hard_swish",
name="conv_last",
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
def forward(self, x):
x = self.conv1(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.pool(x)
return x
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='hard_swish'):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias=False)
self._batch_norm = nn.BatchNorm2d(
num_filters,
)
if self.act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DepthwiseSeparable(nn.Module):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
dw_size=3,
padding=1,
use_se=False):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale))
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Module):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=in_channels,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
def hardsigmoid(x):
return F.relu6(x + 3., inplace=True) / 6.
class SEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
bias=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
x = torch.mul(inputs, outputs)
return x
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveAvgPool2D(nn.AdaptiveAvgPool2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(self.output_size, int) and self.output_size == 1:
self._gap = True
elif (
isinstance(self.output_size, tuple)
and self.output_size[0] == 1
and self.output_size[1] == 1
):
self._gap = True
else:
self._gap = False
def forward(self, x):
if self._gap:
# Global Average Pooling
N, C, _, _ = x.shape
x_mean = torch.mean(x, dim=[2, 3])
x_mean = torch.reshape(x_mean, [N, C, 1, 1])
return x_mean
else:
return F.adaptive_avg_pool2d(
x,
output_size=self.output_size
)
class LearnableAffineBlock(nn.Module):
"""
Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
Args:
scale_value (float): The initial value of the scale parameter, default is 1.0.
bias_value (float): The initial value of the bias parameter, default is 0.0.
lr_mult (float): The learning rate multiplier, default is 1.0.
lab_lr (float): The learning rate, default is 0.01.
"""
def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
super().__init__()
self.scale = nn.Parameter(torch.Tensor([scale_value]))
self.bias = nn.Parameter(torch.Tensor([bias_value]))
def forward(self, x):
return self.scale * x + self.bias
class ConvBNAct(nn.Module):
"""
ConvBNAct is a combination of convolution and batchnorm layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
stride (int): Stride of the convolution. Defaults to 1.
padding (int/str): Padding or padding type for the convolution. Defaults to 1.
groups (int): Number of groups for the convolution. Defaults to 1.
use_act: (bool): Whether to use activation function. Defaults to True.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
groups=1,
use_act=True,
use_lab=False,
lr_mult=1.0,
):
super().__init__()
self.use_act = use_act
self.use_lab = use_lab
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
# padding=(kernel_size - 1) // 2,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(
out_channels,
)
if self.use_act:
self.act = nn.ReLU()
if self.use_lab:
self.lab = LearnableAffineBlock(lr_mult=lr_mult)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
if self.use_lab:
x = self.lab(x)
return x
class LightConvBNAct(nn.Module):
"""
LightConvBNAct is a combination of pw and dw layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the depth-wise convolution kernel.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
use_lab=False,
lr_mult=1.0,
**kwargs,
):
super().__init__()
self.conv1 = ConvBNAct(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
use_act=False,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.conv2 = ConvBNAct(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=out_channels,
use_act=True,
use_lab=use_lab,
lr_mult=lr_mult,
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class CustomMaxPool2d(nn.Module):
def __init__(
self,
kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False,
data_format="NCHW",
):
super(CustomMaxPool2d, self).__init__()
self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
self.stride = stride if stride is not None else self.kernel_size
self.stride = self.stride if isinstance(self.stride, (tuple, list)) else (self.stride, self.stride)
self.dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
self.return_indices = return_indices
self.ceil_mode = ceil_mode
self.padding_mode = padding
# 当padding不是"same"时使用标准MaxPool2d
if padding != "same":
self.padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
self.pool = nn.MaxPool2d(
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
return_indices=self.return_indices,
ceil_mode=self.ceil_mode
)
def forward(self, x):
# 处理same padding
if self.padding_mode == "same":
input_height, input_width = x.size(2), x.size(3)
# 计算期望的输出尺寸
out_height = math.ceil(input_height / self.stride[0])
out_width = math.ceil(input_width / self.stride[1])
# 计算需要的padding
pad_height = max((out_height - 1) * self.stride[0] + self.kernel_size[0] - input_height, 0)
pad_width = max((out_width - 1) * self.stride[1] + self.kernel_size[1] - input_width, 0)
# 将padding分配到两边
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
# 应用padding
x = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom))
# 使用标准max_pool2d函数
if self.return_indices:
return F.max_pool2d_with_indices(
x,
kernel_size=self.kernel_size,
stride=self.stride,
padding=0, # 已经手动pad过了
dilation=self.dilation,
ceil_mode=self.ceil_mode
)
else:
return F.max_pool2d(
x,
kernel_size=self.kernel_size,
stride=self.stride,
padding=0, # 已经手动pad过了
dilation=self.dilation,
ceil_mode=self.ceil_mode
)
else:
# 使用预定义的MaxPool2d
return self.pool(x)
class StemBlock(nn.Module):
"""
StemBlock for PP-HGNetV2.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
use_lab=False,
lr_mult=1.0,
text_rec=False,
):
super().__init__()
self.stem1 = ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=2,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem2a = ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels // 2,
kernel_size=2,
stride=1,
padding="same",
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem2b = ConvBNAct(
in_channels=mid_channels // 2,
out_channels=mid_channels,
kernel_size=2,
stride=1,
padding="same",
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem3 = ConvBNAct(
in_channels=mid_channels * 2,
out_channels=mid_channels,
kernel_size=3,
stride=1 if text_rec else 2,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem4 = ConvBNAct(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.pool = CustomMaxPool2d(
kernel_size=2, stride=1, ceil_mode=True, padding="same"
)
# self.pool = nn.MaxPool2d(
# kernel_size=2, stride=1, ceil_mode=True, padding=1
# )
def forward(self, x):
x = self.stem1(x)
x2 = self.stem2a(x)
x2 = self.stem2b(x2)
x1 = self.pool(x)
# if x1.shape[2:] != x2.shape[2:]:
# x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x1, x2], 1)
x = self.stem3(x)
x = self.stem4(x)
return x
class HGV2_Block(nn.Module):
"""
HGV2_Block, the basic unit that constitutes the HGV2_Stage.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
stride (int): Stride of the convolution. Defaults to 1.
padding (int/str): Padding or padding type for the convolution. Defaults to 1.
groups (int): Number of groups for the convolution. Defaults to 1.
use_act (bool): Whether to use activation function. Defaults to True.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
kernel_size=3,
layer_num=6,
identity=False,
light_block=True,
use_lab=False,
lr_mult=1.0,
):
super().__init__()
self.identity = identity
self.layers = nn.ModuleList()
block_type = "LightConvBNAct" if light_block else "ConvBNAct"
for i in range(layer_num):
self.layers.append(
eval(block_type)(
in_channels=in_channels if i == 0 else mid_channels,
out_channels=mid_channels,
stride=1,
kernel_size=kernel_size,
use_lab=use_lab,
lr_mult=lr_mult,
)
)
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_squeeze_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.aggregation_excitation_conv = ConvBNAct(
in_channels=out_channels // 2,
out_channels=out_channels,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
x = self.aggregation_squeeze_conv(x)
x = self.aggregation_excitation_conv(x)
if self.identity:
x += identity
return x
class HGV2_Stage(nn.Module):
"""
HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
block_num (int): Number of blocks in the HGV2 stage.
layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
is_downsample (bool): Whether to use downsampling operation. Defaults to False.
light_block (bool): Whether to use light block. Defaults to True.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
use_lab (bool, optional): Whether to use the LAB operation. Defaults to False.
lr_mult (float, optional): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num=6,
is_downsample=True,
light_block=True,
kernel_size=3,
use_lab=False,
stride=2,
lr_mult=1.0,
):
super().__init__()
self.is_downsample = is_downsample
if self.is_downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
groups=in_channels,
use_act=False,
use_lab=use_lab,
lr_mult=lr_mult,
)
blocks_list = []
for i in range(block_num):
blocks_list.append(
HGV2_Block(
in_channels=in_channels if i == 0 else out_channels,
mid_channels=mid_channels,
out_channels=out_channels,
kernel_size=kernel_size,
layer_num=layer_num,
identity=False if i == 0 else True,
light_block=light_block,
use_lab=use_lab,
lr_mult=lr_mult,
)
)
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.is_downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class DropoutInferDownscale(nn.Module):
"""
实现与Paddle的mode="downscale_in_infer"等效的Dropout
训练模式:out = input * mask(直接应用掩码,不进行放大)
推理模式:out = input * (1.0 - p)(在推理时按概率缩小)
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, x):
if self.training:
# 训练时:应用随机mask但不放大
return F.dropout(x, self.p, training=True) * (1.0 - self.p)
else:
# 推理时:按照dropout概率缩小输出
return x * (1.0 - self.p)
class PPHGNetV2(nn.Module):
"""
PPHGNetV2
Args:
stage_config (dict): Config for PPHGNetV2 stages. such as the number of channels, stride, etc.
stem_channels: (list): Number of channels of the stem of the PPHGNetV2.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
use_last_conv (bool): Whether to use the last conv layer as the output channel. Defaults to True.
class_expand (int): Number of channels for the last 1x1 convolutional layer.
drop_prob (float): Dropout probability for the last 1x1 convolutional layer. Defaults to 0.0.
class_num (int): The number of classes for the classification layer. Defaults to 1000.
lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
Returns:
model: nn.Layer. Specific PPHGNetV2 model depends on args.
"""
def __init__(
self,
stage_config,
stem_channels=[3, 32, 64],
use_lab=False,
use_last_conv=True,
class_expand=2048,
dropout_prob=0.0,
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
det=False,
text_rec=False,
out_indices=None,
**kwargs,
):
super().__init__()
self.det = det
self.text_rec = text_rec
self.use_lab = use_lab
self.use_last_conv = use_last_conv
self.class_expand = class_expand
self.class_num = class_num
self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
self.out_channels = []
# stem
self.stem = StemBlock(
in_channels=stem_channels[0],
mid_channels=stem_channels[1],
out_channels=stem_channels[2],
use_lab=use_lab,
lr_mult=lr_mult_list[0],
text_rec=text_rec,
)
# stages
self.stages = nn.ModuleList()
for i, k in enumerate(stage_config):
(
in_channels,
mid_channels,
out_channels,
block_num,
is_downsample,
light_block,
kernel_size,
layer_num,
stride,
) = stage_config[k]
self.stages.append(
HGV2_Stage(
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
is_downsample,
light_block,
kernel_size,
use_lab,
stride,
lr_mult=lr_mult_list[i + 1],
)
)
if i in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self.avg_pool = AdaptiveAvgPool2D(1)
if self.use_last_conv:
self.last_conv = nn.Conv2d(
in_channels=out_channels,
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.act = nn.ReLU()
if self.use_lab:
self.lab = LearnableAffineBlock()
self.dropout = DropoutInferDownscale(p=dropout_prob)
self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
if not self.det:
self.fc = nn.Linear(
self.class_expand if self.use_last_conv else out_channels,
self.class_num,
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.text_rec:
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B0
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B0` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [16, 16, 64, 1, False, False, 3, 3],
"stage2": [64, 32, 256, 1, True, False, 3, 3],
"stage3": [256, 64, 512, 2, True, True, 5, 3],
"stage4": [512, 128, 1024, 1, True, True, 5, 3],
}
model = PPHGNetV2(
stem_channels=[3, 16, 16], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B1
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B1` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [32, 32, 64, 1, False, False, 3, 3],
"stage2": [64, 48, 256, 1, True, False, 3, 3],
"stage3": [256, 96, 512, 2, True, True, 5, 3],
"stage4": [512, 192, 1024, 1, True, True, 5, 3],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B2
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B2` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [32, 32, 96, 1, False, False, 3, 4],
"stage2": [96, 64, 384, 1, True, False, 3, 4],
"stage3": [384, 128, 768, 3, True, True, 5, 4],
"stage4": [768, 256, 1536, 1, True, True, 5, 4],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B3
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B3` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [32, 32, 128, 1, False, False, 3, 5],
"stage2": [128, 64, 512, 1, True, False, 3, 5],
"stage3": [512, 128, 1024, 3, True, True, 5, 5],
"stage4": [1024, 256, 2048, 1, True, True, 5, 5],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
"""
PPHGNetV2_B4
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
"""
stage_config_rec = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
"stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
}
stage_config_det = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
}
model = PPHGNetV2(
stem_channels=[3, 32, 48],
stage_config=stage_config_det if det else stage_config_rec,
use_lab=False,
det=det,
text_rec=text_rec,
**kwargs,
)
return model
def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B5
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B5` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [64, 64, 128, 1, False, False, 3, 6],
"stage2": [128, 128, 512, 2, True, False, 3, 6],
"stage3": [512, 256, 1024, 5, True, True, 5, 6],
"stage4": [1024, 512, 2048, 2, True, True, 5, 6],
}
model = PPHGNetV2(
stem_channels=[3, 32, 64], stage_config=stage_config, use_lab=False, **kwargs
)
return model
def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B6
Args:
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B6` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
"stage1": [96, 96, 192, 2, False, False, 3, 6],
"stage2": [192, 192, 512, 3, True, False, 3, 6],
"stage3": [512, 384, 1024, 6, True, True, 5, 6],
"stage4": [1024, 768, 2048, 3, True, True, 5, 6],
}
model = PPHGNetV2(
stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
)
return model
import numpy as np
import torch
from torch import nn
from ..common import Activation
def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = torch.as_tensor(1 - drop_prob)
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
random_tensor = torch.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
class ConvBNLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act="gelu",
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias_attr,
)
self.norm = nn.BatchNorm2d(out_channels)
self.act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer="gelu",
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = Activation(act_type=act_layer, inplace=True)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
num_heads=8,
HW=[8, 25],
local_k=[3, 3],
):
super().__init__()
self.HW = HW
self.dim = dim
self.local_mixer = nn.Conv2d(
dim,
dim,
local_k,
1,
[local_k[0] // 2, local_k[1] // 2],
groups=num_heads,
)
def forward(self, x):
h = self.HW[0]
w = self.HW[1]
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
x = self.local_mixer(x)
x = x.flatten(2).permute(0, 2, 1)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
mixer="Global",
HW=[8, 25],
local_k=[7, 11],
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.HW = HW
if HW is not None:
H = HW[0]
W = HW[1]
self.N = H * W
self.C = dim
if mixer == "Local" and HW is not None:
hk = local_k[0]
wk = local_k[1]
mask = torch.ones(H * W, H + hk - 1, W + wk - 1, dtype=torch.float32)
for h in range(0, H):
for w in range(0, W):
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
1
)
mask_inf = torch.full(
[H * W, H * W], fill_value=float("-Inf"), dtype=torch.float32
)
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask.unsqueeze(0).unsqueeze(1)
# self.mask = mask[None, None, :]
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute(
2, 0, 3, 1, 4
)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q.matmul(k.permute(0, 1, 3, 2))
if self.mixer == "Local":
attn += self.mask
attn = nn.functional.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).permute(0, 2, 1, 3).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mixer="Global",
local_mixer=[7, 11],
HW=None,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer="gelu",
norm_layer="nn.LayerNorm",
epsilon=1e-6,
prenorm=True,
):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm1 = norm_layer(dim)
if mixer == "Global" or mixer == "Local":
self.mixer = Attention(
dim,
num_heads=num_heads,
mixer=mixer,
HW=HW,
local_k=local_mixer,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
elif mixer == "Conv":
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.prenorm = prenorm
def forward(self, x):
if self.prenorm:
x = self.norm1(x + self.drop_path(self.mixer(x)))
x = self.norm2(x + self.drop_path(self.mlp(x)))
else:
x = x + self.drop_path(self.mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(
self,
img_size=[32, 100],
in_channels=3,
embed_dim=768,
sub_num=2,
patch_size=[4, 4],
mode="pope",
):
super().__init__()
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
if mode == "pope":
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
)
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 4,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
)
elif mode == "linear":
self.proj = nn.Conv2d(
1, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.num_patches = (
img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
)
def forward(self, x):
B, C, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), "Input image size ({}*{}) doesn't match model ({}*{}).".format(
H, W, self.img_size[0], self.img_size[1]
)
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
class SubSample(nn.Module):
def __init__(
self,
in_channels,
out_channels,
types="Pool",
stride=[2, 1],
sub_norm="nn.LayerNorm",
act=None,
):
super().__init__()
self.types = types
if types == "Pool":
self.avgpool = nn.AvgPool2d(
kernel_size=[3, 5], stride=stride, padding=[1, 2]
)
self.maxpool = nn.MaxPool2d(
kernel_size=[3, 5], stride=stride, padding=[1, 2]
)
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
else:
self.act = None
def forward(self, x):
if self.types == "Pool":
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
out = self.proj(x.flatten(2).permute(0, 2, 1))
else:
x = self.conv(x)
out = x.flatten(2).permute(0, 2, 1)
out = self.norm(out)
if self.act is not None:
out = self.act(out)
return out
class SVTRNet(nn.Module):
def __init__(
self,
img_size=[32, 100],
in_channels=3,
embed_dim=[64, 128, 256],
depth=[3, 6, 3],
num_heads=[2, 4, 8],
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
local_mixer=[[7, 11], [7, 11], [7, 11]],
patch_merging="Conv", # Conv, Pool, None
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
last_drop=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer="nn.LayerNorm",
sub_norm="nn.LayerNorm",
epsilon=1e-6,
out_channels=192,
out_char_num=25,
block_unit="Block",
act="gelu",
last_stage=True,
sub_num=2,
prenorm=True,
use_lenhead=False,
**kwargs
):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
patch_merging = (
None
if patch_merging != "Conv" and patch_merging != "Pool"
else patch_merging
)
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim[0],
sub_num=sub_num,
)
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[0],
num_heads=num_heads[0],
mixer=mixer[0 : depth[0]][i],
HW=self.HW,
local_mixer=local_mixer[0],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[0 : depth[0]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[0])
]
)
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0],
embed_dim[1],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging,
)
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
self.blocks2 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[1],
num_heads=num_heads[1],
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
HW=HW,
local_mixer=local_mixer[1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[1])
]
)
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1],
embed_dim[2],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging,
)
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
self.blocks3 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[2],
num_heads=num_heads[2],
mixer=mixer[depth[0] + depth[1] :][i],
HW=HW,
local_mixer=local_mixer[2],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1] :][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[2])
]
)
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
self.last_conv = nn.Conv2d(
in_channels=embed_dim[2],
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.hardswish = Activation("hard_swish", inplace=True) # nn.Hardswish()
# self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = Activation(
"hard_swish", inplace=True
) # nn.Hardswish()
self.dropout_len = nn.Dropout(p=last_drop)
torch.nn.init.xavier_normal_(self.pos_embed)
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample1(
x.permute(0, 2, 1).reshape(
[-1, self.embed_dim[0], self.HW[0], self.HW[1]]
)
)
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(
x.permute(0, 2, 1).reshape(
[-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
)
)
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.use_lenhead:
len_x = self.len_conv(x.mean(1))
len_x = self.dropout_len(self.hardswish_len(len_x))
if self.last_stage:
if self.patch_merging is not None:
h = self.HW[0] // 4
else:
h = self.HW[0]
x = self.avg_pool(
x.permute(0, 2, 1).reshape([-1, self.embed_dim[2], h, self.HW[1]])
)
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.use_lenhead:
return x, len_x
return x
import torch
import torch.nn.functional as F
from torch import nn
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
# out = max(0, min(1, slop*x+offset))
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
class GELU(nn.Module):
def __init__(self, inplace=True):
super(GELU, self).__init__()
self.inplace = inplace
def forward(self, x):
return torch.nn.functional.gelu(x)
class Swish(nn.Module):
def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
if self.inplace:
x.mul_(torch.sigmoid(x))
return x
else:
return x * torch.sigmoid(x)
class Activation(nn.Module):
def __init__(self, act_type, inplace=True):
super(Activation, self).__init__()
act_type = act_type.lower()
if act_type == "relu":
self.act = nn.ReLU(inplace=inplace)
elif act_type == "relu6":
self.act = nn.ReLU6(inplace=inplace)
elif act_type == "sigmoid":
raise NotImplementedError
elif act_type == "hard_sigmoid":
self.act = Hsigmoid(
inplace
) # nn.Hardsigmoid(inplace=inplace)#Hsigmoid(inplace)#
elif act_type == "hard_swish" or act_type == "hswish":
self.act = Hswish(inplace=inplace)
elif act_type == "leakyrelu":
self.act = nn.LeakyReLU(inplace=inplace)
elif act_type == "gelu":
self.act = GELU(inplace=inplace)
elif act_type == "swish":
self.act = Swish(inplace=inplace)
else:
raise NotImplementedError
def forward(self, inputs):
return self.act(inputs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_head"]
def build_head(config, **kwargs):
# det head
from .det_db_head import DBHead, PFHeadLocal
# rec head
from .rec_ctc_head import CTCHead
from .rec_multi_head import MultiHead
# cls head
from .cls_head import ClsHead
support_dict = [
"DBHead",
"CTCHead",
"ClsHead",
"MultiHead",
"PFHeadLocal",
]
module_name = config.pop("name")
char_num = config.pop("char_num", 6625)
assert module_name in support_dict, Exception(
"head only support {}".format(support_dict)
)
module_class = eval(module_name)(**config, **kwargs)
return module_class
import torch
import torch.nn.functional as F
from torch import nn
class ClsHead(nn.Module):
"""
Class orientation
Args:
params(dict): super parameters for build Class network
"""
def __init__(self, in_channels, class_dim, **kwargs):
super(ClsHead, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(in_channels, class_dim, bias=True)
def forward(self, x):
x = self.pool(x)
x = torch.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
x = F.softmax(x, dim=1)
return x
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..common import Activation
from ..backbones.det_mobilenet_v3 import ConvBNLayer
class Head(nn.Module):
def __init__(self, in_channels, **kwargs):
super(Head, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels // 4,
kernel_size=3,
padding=1,
bias=False)
self.conv_bn1 = nn.BatchNorm2d(
in_channels // 4)
self.relu1 = Activation(act_type='relu')
self.conv2 = nn.ConvTranspose2d(
in_channels=in_channels // 4,
out_channels=in_channels // 4,
kernel_size=2,
stride=2)
self.conv_bn2 = nn.BatchNorm2d(
in_channels // 4)
self.relu2 = Activation(act_type='relu')
self.conv3 = nn.ConvTranspose2d(
in_channels=in_channels // 4,
out_channels=1,
kernel_size=2,
stride=2)
def forward(self, x, return_f=False):
x = self.conv1(x)
x = self.conv_bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.conv_bn2(x)
x = self.relu2(x)
if return_f is True:
f = x
x = self.conv3(x)
x = torch.sigmoid(x)
if return_f is True:
return x, f
return x
class DBHead(nn.Module):
"""
Differentiable Binarization (DB) for text detection:
see https://arxiv.org/abs/1911.08947
args:
params(dict): super parameters for build DB network
"""
def __init__(self, in_channels, k=50, **kwargs):
super(DBHead, self).__init__()
self.k = k
binarize_name_list = [
'conv2d_56', 'batch_norm_47', 'conv2d_transpose_0', 'batch_norm_48',
'conv2d_transpose_1', 'binarize'
]
thresh_name_list = [
'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
'conv2d_transpose_3', 'thresh'
]
self.binarize = Head(in_channels, **kwargs)# binarize_name_list)
self.thresh = Head(in_channels, **kwargs)#thresh_name_list)
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
def forward(self, x):
shrink_maps = self.binarize(x)
return {'maps': shrink_maps}
class LocalModule(nn.Module):
def __init__(self, in_c, mid_c, use_distance=True):
super(self.__class__, self).__init__()
self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
self.last_1 = nn.Conv2d(mid_c, 1, 1, 1, 0)
def forward(self, x, init_map, distance_map):
outf = torch.cat([init_map, x], dim=1)
# last Conv
out = self.last_1(self.last_3(outf))
return out
class PFHeadLocal(DBHead):
def __init__(self, in_channels, k=50, mode='small', **kwargs):
super(PFHeadLocal, self).__init__(in_channels, k, **kwargs)
self.mode = mode
self.up_conv = nn.Upsample(scale_factor=2, mode="nearest")
if self.mode == 'large':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
elif self.mode == 'small':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
def forward(self, x, targets=None):
shrink_maps, f = self.binarize(x, return_f=True)
base_maps = shrink_maps
cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
cbn_maps = F.sigmoid(cbn_maps)
return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
\ No newline at end of file
import torch.nn.functional as F
from torch import nn
class CTCHead(nn.Module):
def __init__(
self,
in_channels,
out_channels=6625,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs
):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,
)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = (x, predicts)
else:
result = predicts
if not self.training:
predicts = F.softmax(predicts, dim=2)
result = predicts
return result
from torch import nn
from ..necks.rnn import Im2Seq, SequenceEncoder
from .rec_ctc_head import CTCHead
class FCTranspose(nn.Module):
def __init__(self, in_channels, out_channels, only_transpose=False):
super().__init__()
self.only_transpose = only_transpose
if not self.only_transpose:
self.fc = nn.Linear(in_channels, out_channels, bias=False)
def forward(self, x):
if self.only_transpose:
return x.permute([0, 2, 1])
else:
return self.fc(x.permute([0, 2, 1]))
class MultiHead(nn.Module):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop("head_list")
self.gtc_head = "sar"
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
if name == "SARHead":
pass
elif name == "NRTRHead":
pass
elif name == "CTCHead":
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]["Neck"]
encoder_type = neck_args.pop("name")
self.ctc_encoder = SequenceEncoder(
in_channels=in_channels, encoder_type=encoder_type, **neck_args
)
# ctc head
head_args = self.head_list[idx][name].get("Head", {})
if head_args is None:
head_args = {}
self.ctc_head = CTCHead(
in_channels=self.ctc_encoder.out_channels,
out_channels=out_channels_list["CTCLabelDecode"],
**head_args,
)
else:
raise NotImplementedError(f"{name} is not supported in MultiHead yet")
def forward(self, x, data=None):
ctc_encoder = self.ctc_encoder(x)
return self.ctc_head(ctc_encoder)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_neck"]
def build_neck(config):
from .db_fpn import DBFPN, LKPAN, RSEFPN
from .rnn import SequenceEncoder
support_dict = ["DBFPN", "SequenceEncoder", "RSEFPN", "LKPAN"]
module_name = config.pop("name")
assert module_name in support_dict, Exception(
"neck only support {}".format(support_dict)
)
module_class = eval(module_name)(**config)
return module_class
import torch
import torch.nn.functional as F
from torch import nn
from ..backbones.det_mobilenet_v3 import SEModule
from ..necks.intracl import IntraCLBlock
def hard_swish(x, inplace=True):
return x * F.relu6(x + 3.0, inplace=inplace) / 6.0
class DSConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding,
stride=1,
groups=None,
if_act=True,
act="relu",
**kwargs
):
super(DSConv, self).__init__()
if groups == None:
groups = in_channels
self.if_act = if_act
self.act = act
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(
in_channels=in_channels,
out_channels=int(in_channels * 4),
kernel_size=1,
stride=1,
bias=False,
)
self.bn2 = nn.BatchNorm2d(int(in_channels * 4))
self.conv3 = nn.Conv2d(
in_channels=int(in_channels * 4),
out_channels=out_channels,
kernel_size=1,
stride=1,
bias=False,
)
self._c = [in_channels, out_channels]
if in_channels != out_channels:
self.conv_end = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
bias=False,
)
def forward(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = hard_swish(x)
else:
print(
"The activation function({}) is selected incorrectly.".format(
self.act
)
)
exit()
x = self.conv3(x)
if self._c[0] != self._c[1]:
x = x + self.conv_end(inputs)
return x
class DBFPN(nn.Module):
def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__()
self.out_channels = out_channels
self.use_asf = use_asf
self.in2_conv = nn.Conv2d(
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
bias=False,
)
self.in3_conv = nn.Conv2d(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
bias=False,
)
self.in4_conv = nn.Conv2d(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
bias=False,
)
self.in5_conv = nn.Conv2d(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
bias=False,
)
self.p5_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False,
)
self.p4_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False,
)
self.p3_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False,
)
self.p2_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False,
)
if self.use_asf is True:
self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.in5_conv(c5)
in4 = self.in4_conv(c4)
in3 = self.in3_conv(c3)
in2 = self.in2_conv(c2)
out4 = in4 + F.interpolate(
in5,
scale_factor=2,
mode="nearest",
) # align_mode=1) # 1/16
out3 = in3 + F.interpolate(
out4,
scale_factor=2,
mode="nearest",
) # align_mode=1) # 1/8
out2 = in2 + F.interpolate(
out3,
scale_factor=2,
mode="nearest",
) # align_mode=1) # 1/4
p5 = self.p5_conv(in5)
p4 = self.p4_conv(out4)
p3 = self.p3_conv(out3)
p2 = self.p2_conv(out2)
p5 = F.interpolate(
p5,
scale_factor=8,
mode="nearest",
) # align_mode=1)
p4 = F.interpolate(
p4,
scale_factor=4,
mode="nearest",
) # align_mode=1)
p3 = F.interpolate(
p3,
scale_factor=2,
mode="nearest",
) # align_mode=1)
fuse = torch.cat([p5, p4, p3, p2], dim=1)
if self.use_asf is True:
fuse = self.asf(fuse, [p5, p4, p3, p2])
return fuse
class RSELayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
super(RSELayer, self).__init__()
self.out_channels = out_channels
self.in_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
padding=int(kernel_size // 2),
bias=False,
)
self.se_block = SEModule(self.out_channels)
self.shortcut = shortcut
def forward(self, ins):
x = self.in_conv(ins)
if self.shortcut:
out = x + self.se_block(x)
else:
out = self.se_block(x)
return out
class RSEFPN(nn.Module):
def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
super(RSEFPN, self).__init__()
self.out_channels = out_channels
self.ins_conv = nn.ModuleList()
self.inp_conv = nn.ModuleList()
self.intracl = False
if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
self.intracl = kwargs["intracl"]
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
for i in range(len(in_channels)):
self.ins_conv.append(
RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut)
)
self.inp_conv.append(
RSELayer(
out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut
)
)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest") # 1/16
out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest") # 1/8
out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest") # 1/4
p5 = self.inp_conv[3](in5)
p4 = self.inp_conv[2](out4)
p3 = self.inp_conv[1](out3)
p2 = self.inp_conv[0](out2)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
fuse = torch.cat([p5, p4, p3, p2], dim=1)
return fuse
class LKPAN(nn.Module):
def __init__(self, in_channels, out_channels, mode="large", **kwargs):
super(LKPAN, self).__init__()
self.out_channels = out_channels
self.ins_conv = nn.ModuleList()
self.inp_conv = nn.ModuleList()
# pan head
self.pan_head_conv = nn.ModuleList()
self.pan_lat_conv = nn.ModuleList()
if mode.lower() == "lite":
p_layer = DSConv
elif mode.lower() == "large":
p_layer = nn.Conv2d
else:
raise ValueError(
"mode can only be one of ['lite', 'large'], but received {}".format(
mode
)
)
for i in range(len(in_channels)):
self.ins_conv.append(
nn.Conv2d(
in_channels=in_channels[i],
out_channels=self.out_channels,
kernel_size=1,
bias=False,
)
)
self.inp_conv.append(
p_layer(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
bias=False,
)
)
if i > 0:
self.pan_head_conv.append(
nn.Conv2d(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
self.pan_lat_conv.append(
p_layer(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
bias=False,
)
)
self.intracl = False
if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
self.intracl = kwargs["intracl"]
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest") # 1/16
out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest") # 1/8
out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest") # 1/4
f5 = self.inp_conv[3](in5)
f4 = self.inp_conv[2](out4)
f3 = self.inp_conv[1](out3)
f2 = self.inp_conv[0](out2)
pan3 = f3 + self.pan_head_conv[0](f2)
pan4 = f4 + self.pan_head_conv[1](pan3)
pan5 = f5 + self.pan_head_conv[2](pan4)
p2 = self.pan_lat_conv[0](f2)
p3 = self.pan_lat_conv[1](pan3)
p4 = self.pan_lat_conv[2](pan4)
p5 = self.pan_lat_conv[3](pan5)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
fuse = torch.cat([p5, p4, p3, p2], dim=1)
return fuse
class ASFBlock(nn.Module):
"""
This code is refered from:
https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
"""
def __init__(self, in_channels, inter_channels, out_features_num=4):
"""
Adaptive Scale Fusion (ASF) block of DBNet++
Args:
in_channels: the number of channels in the input data
inter_channels: the number of middle channels
out_features_num: the number of fused stages
"""
super(ASFBlock, self).__init__()
self.in_channels = in_channels
self.inter_channels = inter_channels
self.out_features_num = out_features_num
self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
self.spatial_scale = nn.Sequential(
# Nx1xHxW
nn.Conv2d(
in_channels=1,
out_channels=1,
kernel_size=3,
bias=False,
padding=1,
),
nn.ReLU(),
nn.Conv2d(
in_channels=1,
out_channels=1,
kernel_size=1,
bias=False,
),
nn.Sigmoid(),
)
self.channel_scale = nn.Sequential(
nn.Conv2d(
in_channels=inter_channels,
out_channels=out_features_num,
kernel_size=1,
bias=False,
),
nn.Sigmoid(),
)
def forward(self, fuse_features, features_list):
fuse_features = self.conv(fuse_features)
spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
attention_scores = self.spatial_scale(spatial_x) + fuse_features
attention_scores = self.channel_scale(attention_scores)
assert len(features_list) == self.out_features_num
out_list = []
for i in range(self.out_features_num):
out_list.append(attention_scores[:, i : i + 1] * features_list[i])
return torch.cat(out_list, dim=1)
from torch import nn
class IntraCLBlock(nn.Module):
def __init__(self, in_channels=96, reduce_factor=4):
super(IntraCLBlock, self).__init__()
self.channels = in_channels
self.rf = reduce_factor
self.conv1x1_reduce_channel = nn.Conv2d(
self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0
)
self.conv1x1_return_channel = nn.Conv2d(
self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0
)
self.v_layer_7x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 1),
stride=(1, 1),
padding=(3, 0),
)
self.v_layer_5x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 1),
stride=(1, 1),
padding=(2, 0),
)
self.v_layer_3x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 1),
stride=(1, 1),
padding=(1, 0),
)
self.q_layer_1x7 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 7),
stride=(1, 1),
padding=(0, 3),
)
self.q_layer_1x5 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 5),
stride=(1, 1),
padding=(0, 2),
)
self.q_layer_1x3 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 3),
stride=(1, 1),
padding=(0, 1),
)
# base
self.c_layer_7x7 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 7),
stride=(1, 1),
padding=(3, 3),
)
self.c_layer_5x5 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2),
)
self.c_layer_3x3 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
)
self.bn = nn.BatchNorm2d(self.channels)
self.relu = nn.ReLU()
def forward(self, x):
x_new = self.conv1x1_reduce_channel(x)
x_7_c = self.c_layer_7x7(x_new)
x_7_v = self.v_layer_7x1(x_new)
x_7_q = self.q_layer_1x7(x_new)
x_7 = x_7_c + x_7_v + x_7_q
x_5_c = self.c_layer_5x5(x_7)
x_5_v = self.v_layer_5x1(x_7)
x_5_q = self.q_layer_1x5(x_7)
x_5 = x_5_c + x_5_v + x_5_q
x_3_c = self.c_layer_3x3(x_5)
x_3_v = self.v_layer_3x1(x_5)
x_3_q = self.q_layer_1x3(x_5)
x_3 = x_3_c + x_3_v + x_3_q
x_relation = self.conv1x1_return_channel(x_3)
x_relation = self.bn(x_relation)
x_relation = self.relu(x_relation)
return x + x_relation
def build_intraclblock_list(num_block):
IntraCLBlock_list = nn.ModuleList()
for i in range(num_block):
IntraCLBlock_list.append(IntraCLBlock())
return IntraCLBlock_list
import torch
from torch import nn
from ..backbones.rec_svtrnet import Block, ConvBNLayer
class Im2Seq(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
# def forward(self, x):
# B, C, H, W = x.shape
# # assert H == 1
# x = x.squeeze(dim=2)
# # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
# x = x.permute(0, 2, 1)
# return x
def forward(self, x):
B, C, H, W = x.shape
# 处理四维张量,将空间维度展平为序列
if H == 1:
# 原来的处理逻辑,适用于H=1的情况
x = x.squeeze(dim=2)
x = x.permute(0, 2, 1) # (B, W, C)
else:
# 处理H不为1的情况
x = x.permute(0, 2, 3, 1) # (B, H, W, C)
x = x.reshape(B, H * W, C) # (B, H*W, C)
return x
class EncoderWithRNN_(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN_, self).__init__()
self.out_channels = hidden_size * 2
self.rnn1 = nn.LSTM(
in_channels,
hidden_size,
bidirectional=False,
batch_first=True,
num_layers=2,
)
self.rnn2 = nn.LSTM(
in_channels,
hidden_size,
bidirectional=False,
batch_first=True,
num_layers=2,
)
def forward(self, x):
self.rnn1.flatten_parameters()
self.rnn2.flatten_parameters()
out1, h1 = self.rnn1(x)
out2, h2 = self.rnn2(torch.flip(x, [1]))
return torch.cat([out1, torch.flip(out2, [1])], 2)
class EncoderWithRNN(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(
in_channels, hidden_size, num_layers=2, batch_first=True, bidirectional=True
) # batch_first:=True
def forward(self, x):
x, _ = self.lstm(x)
return x
class EncoderWithFC(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size
self.fc = nn.Linear(
in_channels,
hidden_size,
bias=True,
)
def forward(self, x):
x = self.fc(x)
return x
class EncoderWithSVTR(nn.Module):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
kernel_size=[3, 3],
attn_drop_rate=0.1,
drop_path=0.0,
qk_scale=None,
):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels,
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
act="swish",
)
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act="swish"
)
self.svtr_block = nn.ModuleList(
[
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer="Global",
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer="swish",
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer="nn.LayerNorm",
epsilon=1e-05,
prenorm=False,
)
for i in range(depth)
]
)
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act="swish"
)
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# for use guide
if self.use_guide:
z = x.clone()
z.stop_gradient = True
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).permute(0, 2, 1)
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
z = self.conv3(z)
z = torch.cat((h, z), dim=1)
z = self.conv1x1(self.conv4(z))
return z
class SequenceEncoder(nn.Module):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == "reshape":
self.only_reshape = True
else:
support_encoder_dict = {
"reshape": Im2Seq,
"fc": EncoderWithFC,
"rnn": EncoderWithRNN,
"svtr": EncoderWithSVTR,
}
assert encoder_type in support_encoder_dict, "{} must in {}".format(
encoder_type, support_encoder_dict.keys()
)
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, **kwargs
)
else:
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size
)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
if self.encoder_type != "svtr":
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
__all__ = ['build_post_process']
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, TableLabelDecode, \
NRTRLabelDecode, SARLabelDecode, ViTSTRLabelDecode, RFLLabelDecode
from .cls_postprocess import ClsPostProcess
from .rec_postprocess import CANLabelDecode
support_dict = [
'DBPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'TableLabelDecode', 'NRTRLabelDecode', 'SARLabelDecode',
'ViTSTRLabelDecode','CANLabelDecode', 'RFLLabelDecode'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
'post process only support {}, but got {}'.format(support_dict, module_name))
module_class = eval(module_name)(**config)
return module_class
\ No newline at end of file
import torch
class ClsPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, label_list, **kwargs):
super(ClsPostProcess, self).__init__()
self.label_list = label_list
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
pred_idxs = preds.argmax(axis=1)
decode_out = [(self.label_list[idx], preds[i, idx])
for i, idx in enumerate(pred_idxs)]
if label is None:
return decode_out
label = [(self.label_list[idx], 1.0) for idx in label]
return decode_out, label
\ No newline at end of file
"""
This code is refered from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import torch
from shapely.geometry import Polygon
import pyclipper
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, torch.Tensor):
pred = pred.cpu().numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes})
return boxes_batch
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, torch.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, torch.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, torch.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
else:
if isinstance(preds, torch.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
try:
char_idx = self.character[int(text_index[batch_idx][idx])]
except:
continue
if char_idx == '</s>': # end
break
char_list.append(char_idx)
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ViTSTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, torch.Tensor):
preds = preds[:, 1:].numpy()
else:
preds = preds[:, 1:]
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class RFLLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(RFLLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
# if seq_outputs is not None:
if isinstance(preds, tuple) or isinstance(preds, list):
cnt_outputs, seq_outputs = preds
if isinstance(seq_outputs, torch.Tensor):
seq_outputs = seq_outputs.numpy()
preds_idx = seq_outputs.argmax(axis=2)
preds_prob = seq_outputs.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
else:
cnt_outputs = preds
if isinstance(cnt_outputs, torch.Tensor):
cnt_outputs = cnt_outputs.numpy()
cnt_length = []
for lens in cnt_outputs:
length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
return cnt_length
label = self.decode(label, is_remove_duplicate=False)
length = [len(res[0]) for res in label]
return cnt_length, length
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
use_space_char=False,
**kwargs):
self.max_text_length = kwargs.get('max_text_length', 25)
super(SRNLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
char_num = len(self.character_str) + 2
if isinstance(pred, torch.Tensor):
pred = pred.numpy()
pred = np.reshape(pred, [-1, char_num])
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
text = self.decode(preds_idx, preds_prob)
if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text
label = self.decode(label)
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self,
character_dict_path,
**kwargs):
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,torch.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,torch.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(self.end_idx):
if text_prob is None and idx == 0:
continue
else:
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
if self.rm_symbol:
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower()
text = comp.sub('', text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self):
return [self.padding_idx]
class CANLabelDecode(BaseRecLabelDecode):
""" Convert between latex-symbol and symbol-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CANLabelDecode, self).__init__(character_dict_path,
use_space_char)
def decode(self, text_index, preds_prob=None):
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
seq_end = text_index[batch_idx].argmin(0)
idx_list = text_index[batch_idx][:seq_end].tolist()
symbol_list = [self.character[idx] for idx in idx_list]
probs = []
if preds_prob is not None:
probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
result_list.append([' '.join(symbol_list), probs])
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
pred_prob, _, _, _ = preds
preds_idx = pred_prob.argmax(axis=2)
text = self.decode(preds_idx)
if label is None:
return text
label = self.decode(label)
return text, label
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment