Commit a565fa3a authored by luopl

Initial commit
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class DecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
        if six.PY2:
            assert type(img) is str and len(
                img) > 0, "invalid input 'img' in NRTRDecodeImage"
        else:
            assert type(img) is bytes and len(
                img) > 0, "invalid input 'img' in NRTRDecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
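# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original commit): how these transforms
# are typically chained on a raw-bytes sample. The input path is hypothetical;
# DecodeImage returns None for undecodable bytes, and KeepKeys turns the dict
# into a plain list at the end of the pipeline.
# ---------------------------------------------------------------------------
def _demo_transform_pipeline(image_path='demo.jpg'):
    with open(image_path, 'rb') as f:
        data = {'image': f.read(), 'label': 'demo'}
    ops = [
        DecodeImage(img_mode='RGB', channel_first=False),
        NormalizeImage(scale=1.0 / 255.0, order='hwc'),
        ToCHWImage(),
        KeepKeys(keep_keys=['image', 'label']),
    ]
    for op in ops:
        data = op(data)
        if data is None:
            return None
    image, label = data
    return image, label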
class Resize(object):
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['image'] = img_resize
data['polys'] = np.array(new_boxes, dtype=np.float32)
return data
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0:
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
            raise Exception('unsupported limit type: {}'.format(self.limit_type))
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
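# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original commit): resize_image_type0
# scales the short side up to limit_side_len and snaps both sides to
# multiples of 32, e.g. a 720x1280 input with limit_side_len=736 comes out as
# 736x1312, with the corresponding [ratio_h, ratio_w] recorded in 'shape'.
# ---------------------------------------------------------------------------
def _demo_det_resize():
    resize_op = DetResizeForTest(limit_side_len=736, limit_type='min')
    data = resize_op({'image': np.zeros((720, 1280, 3), dtype=np.uint8)})
    resized_h, resized_w = data['image'].shape[:2]
    assert resized_h % 32 == 0 and resized_w % 32 == 0
    return data['shape']  # [src_h, src_w, ratio_h, ratio_w]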
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize(object):
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
        im_resized, scale_factor, [ratio_h, ratio_w], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
__all__ = ["build_model"]
def build_model(config, **kwargs):
from .base_model import BaseModel
config = copy.deepcopy(config)
module_class = BaseModel(config, **kwargs)
return module_class
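# Illustrative sketch (not part of the original commit): build_model expects a
# config dict whose Backbone/Neck/Head sub-dicts each carry a "name" key; the
# sub-module names below are hypothetical and must resolve to classes known to
# build_backbone/build_neck/build_head.
_EXAMPLE_DET_CONFIG = {
    "model_type": "det",
    "in_channels": 3,
    "Backbone": {"name": "MobileNetV3", "model_name": "large", "scale": 0.5},
    "Neck": {"name": "DBFPN", "out_channels": 256},
    "Head": {"name": "DBHead", "k": 50},
}
# model = build_model(_EXAMPLE_DET_CONFIG)  # wires Backbone -> Neck -> Head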
from torch import nn
from ..backbones import build_backbone
from ..heads import build_head
from ..necks import build_neck
class BaseModel(nn.Module):
def __init__(self, config, **kwargs):
"""
the module for OCR.
args:
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
in_channels = config.get("in_channels", 3)
model_type = config["model_type"]
        # build backbone; the backbone is needed for det, rec and cls
if "Backbone" not in config or config["Backbone"] is None:
self.use_backbone = False
else:
self.use_backbone = True
config["Backbone"]["in_channels"] = in_channels
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
# for cls, neck should be none
if "Neck" not in config or config["Neck"] is None:
self.use_neck = False
else:
self.use_neck = True
config["Neck"]["in_channels"] = in_channels
self.neck = build_neck(config["Neck"])
in_channels = self.neck.out_channels
        # build head; the head is needed for det, rec and cls
if "Head" not in config or config["Head"] is None:
self.use_head = False
else:
self.use_head = True
config["Head"]["in_channels"] = in_channels
self.head = build_head(config["Head"], **kwargs)
self.return_all_feats = config.get("return_all_feats", False)
self._initialize_weights()
def _initialize_weights(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
y = dict()
if self.use_backbone:
x = self.backbone(x)
if isinstance(x, dict):
y.update(x)
else:
y["backbone_out"] = x
final_name = "backbone_out"
if self.use_neck:
x = self.neck(x)
if isinstance(x, dict):
y.update(x)
else:
y["neck_out"] = x
final_name = "neck_out"
if self.use_head:
x = self.head(x)
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and "ctc_nect" in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
if self.return_all_feats:
if self.training:
return y
elif isinstance(x, dict):
return x
else:
return {final_name: x}
else:
return x
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_backbone"]
def build_backbone(config, model_type):
if model_type == "det":
from .det_mobilenet_v3 import MobileNetV3
from .rec_hgnet import PPHGNet_small
from .rec_lcnetv3 import PPLCNetV3
from .rec_pphgnetv2 import PPHGNetV2_B4
support_dict = [
"MobileNetV3",
"ResNet",
"ResNet_vd",
"ResNet_SAST",
"PPLCNetV3",
"PPHGNet_small",
            "PPHGNetV2_B4",
]
elif model_type == "rec" or model_type == "cls":
from .rec_hgnet import PPHGNet_small
from .rec_lcnetv3 import PPLCNetV3
from .rec_mobilenet_v3 import MobileNetV3
from .rec_svtrnet import SVTRNet
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_pphgnetv2 import PPHGNetV2_B4
support_dict = [
"MobileNetV1Enhance",
"MobileNetV3",
"ResNet",
"ResNetFPN",
"MTB",
"ResNet31",
"SVTRNet",
"ViTSTR",
"DenseNet",
"PPLCNetV3",
"PPHGNet_small",
"PPHGNetV2_B4",
]
else:
raise NotImplementedError
module_name = config.pop("name")
    assert module_name in support_dict, Exception(
        "when model type is {}, backbone only supports {}".format(
            model_type, support_dict
        )
    )
module_class = eval(module_name)(**config)
return module_class
from torch import nn
from ..common import Activation
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
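# Illustrative sketch (not part of the original commit): sample values of
# make_divisible with the default divisor of 8.
def _demo_make_divisible():
    assert make_divisible(16 * 0.5) == 8   # already a multiple of 8
    assert make_divisible(37) == 40        # rounds to the nearest multiple
    assert make_divisible(10) == 16        # 8 < 0.9 * 10, so bump up one step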
class ConvBNLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None,
):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(
out_channels,
)
if self.if_act:
self.act = Activation(act_type=act, inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
x = self.act(x)
return x
class SEModule(nn.Module):
def __init__(self, in_channels, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.relu1 = Activation(act_type="relu", inplace=True)
self.conv2 = nn.Conv2d(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.hard_sigmoid = Activation(act_type="hard_sigmoid", inplace=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = self.relu1(outputs)
outputs = self.conv2(outputs)
outputs = self.hard_sigmoid(outputs)
outputs = inputs * outputs
return outputs
class ResidualUnit(nn.Module):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None,
name="",
):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand",
)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act,
name=name + "_depthwise",
)
if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear",
)
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = inputs + x
return x
class MobileNetV3(nn.Module):
def __init__(
self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1],
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1],
[5, 288, 96, True, "hard_swish", 2],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError(
"mode[" + model_name + "_model] is not implemented!"
)
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert (
scale in supported_scale
), "supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act="hard_swish",
name="conv1",
)
self.stages = nn.ModuleList()
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for k, exp, c, se, nl, s in cfg:
se = se and not self.disable_se
if s == 2 and i > 2:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2),
)
)
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act="hard_swish",
name="conv_last",
)
)
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
def forward(self, x):
x = self.conv(x)
out_list = []
for stage in self.stages:
x = stage(x)
out_list.append(x)
return out_list
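# Illustrative sketch (not part of the original commit): the detection
# backbone returns one feature map per stage, and out_channels records their
# channel counts for the neck. The input size below is hypothetical.
def _demo_mobilenetv3_det():
    import torch  # local import; this module itself only imports torch.nn
    net = MobileNetV3(in_channels=3, model_name="large", scale=0.5)
    feats = net(torch.randn(1, 3, 640, 640))
    assert len(feats) == len(net.out_channels)
    return [f.shape for f in feats]  # strides 4, 8, 16, 32 w.r.t. the input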
import collections.abc
from collections import OrderedDict
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class DonutSwinConfig(object):
model_type = "donut-swin"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=224,
patch_size=4,
num_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
initializer_range=0.02,
layer_norm_eps=1e-5,
**kwargs,
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        # default for export mode; may be overridden via kwargs below
        self.is_export = kwargs.pop("is_export", False)
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
print(f"Can't set {key} with value {value} for {self}")
raise err
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(OrderedDict):
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().__setattr__(key, value)
def to_tuple(self):
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
class DonutSwinModelOutput(OrderedDict):
    last_hidden_state: Optional[torch.Tensor] = None
    pooler_output: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().__setattr__(key, value)
def to_tuple(self):
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.reshape(
[
batch_size,
height // window_size,
window_size,
width // window_size,
window_size,
num_channels,
]
)
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).reshape(
        [-1, window_size, window_size, num_channels]
    )
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.reshape(
[
-1,
height // window_size,
width // window_size,
window_size,
window_size,
num_channels,
]
)
    windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(
        [-1, height, width, num_channels]
    )
return windows
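# Illustrative sketch (not part of the original commit): window_partition and
# window_reverse are exact inverses whenever height and width are divisible
# by the window size.
def _demo_window_roundtrip():
    feature = torch.randn(2, 14, 14, 96)  # (batch, height, width, channels)
    windows = window_partition(feature, window_size=7)
    assert tuple(windows.shape) == (2 * 2 * 2, 7, 7, 96)
    restored = window_reverse(windows, 7, 14, 14)
    assert torch.equal(restored, feature)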
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
"""
def __init__(self, config, use_mask_token=False):
super().__init__()
self.patch_embeddings = DonutSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
        if use_mask_token:
            # zero-initialized, matching the effective init of the original port
            self.mask_token = nn.Parameter(
                torch.zeros(1, 1, config.embed_dim, dtype=torch.float32)
            )
else:
self.mask_token = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, num_patches + 1, config.embed_dim, dtype=torch.float32)
            )
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.shape
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class MyConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding="same",
        dilation=1,
        groups=1,
        bias_attr=False,  # kept for call-site compatibility with the Paddle port
        eps=1e-6,
    ):
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=True,
        )
        # keep the effective init of the original port: all-ones weight, zero bias
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
class DonutSwinPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = (
image_size
if isinstance(image_size, collections.abc.Iterable)
else (image_size, image_size)
)
patch_size = (
patch_size
if isinstance(patch_size, collections.abc.Iterable)
else (patch_size, patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (
image_size[0] // patch_size[0]
)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.is_export = config.is_export
self.grid_size = (
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
)
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(
self,
input_resolution: Tuple[int],
dim: int,
norm_layer: nn.Module = nn.LayerNorm,
is_export=False,
):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
self.is_export = is_export
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(
self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
) -> torch.Tensor:
height, width = input_dimensions
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.reshape([batch_size, height, width, num_channels])
input_feature = self.maybe_pad(input_feature, height, width)
input_feature_0 = input_feature[:, 0::2, 0::2, :]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
input_feature = torch.cat(
[input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
)
input_feature = input_feature.reshape(
[batch_size, -1, 4 * num_channels]
) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
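# Illustrative sketch (not part of the original commit): patch merging maps a
# (batch, H*W, C) sequence to (batch, (H/2)*(W/2), 2*C) by gathering the four
# spatial neighbors of each 2x2 patch group and projecting 4*C -> 2*C.
def _demo_patch_merging():
    merge = DonutSwinPatchMerging(input_resolution=(14, 14), dim=96)
    x = torch.randn(2, 14 * 14, 96)
    y = merge(x, (14, 14))
    assert tuple(y.shape) == (2, 7 * 7, 192)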
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (
input.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape,
        dtype=input.dtype,
        device=input.device,
    )
random_tensor.floor_() # binarize
output = input / keep_prob * random_tensor
return output
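# Illustrative sketch (not part of the original commit): in training mode,
# drop_path zeroes whole samples with probability drop_prob and rescales the
# survivors by 1/keep_prob, so the output matches the input in expectation;
# in eval mode it is the identity.
def _demo_drop_path():
    x = torch.ones(1000, 8)
    out = drop_path(x, drop_prob=0.2, training=True)
    kept_fraction = (out.sum(dim=1) != 0).float().mean()  # roughly 0.8
    assert torch.equal(drop_path(x, drop_prob=0.2, training=False), x)
    return kept_fraction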
# Copied from transformers.models.swin.modeling_swin.SwinDropPath
class DonutSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class DonutSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size
if isinstance(window_size, collections.abc.Iterable)
else (window_size, window_size)
)
        # zero-initialized, matching the effective init of the original port
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                num_heads,
                dtype=torch.float32,
            )
        )
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
        self.query = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    def transpose_for_scores(self, x):
        new_x_shape = tuple(x.shape[:-1]) + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.reshape(new_x_shape)
        return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask=None,
head_mask=None,
output_attentions=False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.reshape([-1])
]
relative_position_bias = relative_position_bias.reshape(
[
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
]
)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.reshape(
[
batch_size // mask_shape,
mask_shape,
self.num_attention_heads,
dim,
dim,
]
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(
0
)
attention_scores = attention_scores.reshape(
[-1, self.num_attention_heads, dim, dim]
)
# Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = tuple(context_layer.shape[:-2]) + (
self.all_head_size,
)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
self.output = DonutSwinSelfOutput(config, dim)
self.pruned_heads = set()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask=None,
head_mask=None,
output_attentions=False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
self.intermediate_act_fn = F.gelu
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(
config, dim, num_heads, window_size=self.window_size
)
self.drop_path = (
DonutSwinDropPath(config.drop_path_rate)
if config.drop_path_rate > 0.0
else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DonutSwinIntermediate(config, dim)
self.output = DonutSwinOutput(config, dim)
self.is_export = config.is_export
def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask_export(self, height, width, dtype):
attn_mask = None
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
if self.shift_size > 0:
img_mask[:, height_slice, width_slice, :] = count
count += 1
        if self.shift_size > 0:
# calculate attention mask for SW-MSA
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size]
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size]
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
always_partition=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.shape
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.reshape([batch_size, height, width, channels])
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shift_value = (-self.shift_size, -self.shift_size)
if self.is_export:
shift_value = torch.tensor(shift_value, dtype=torch.int32)
shifted_hidden_states = torch.roll(
hidden_states, shifts=shift_value, dims=(1, 2)
)
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(
shifted_hidden_states, self.window_size
)
hidden_states_windows = hidden_states_windows.reshape(
[-1, self.window_size * self.window_size, channels]
)
        if self.is_export:
            attn_mask = self.get_attn_mask_export(
                height_pad, width_pad, dtype=hidden_states.dtype
            )
        else:
            attn_mask = self.get_attn_mask(
                height_pad, width_pad, dtype=hidden_states.dtype
            )
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)
attention_outputs = self.attention(
hidden_states_windows,
attn_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
attention_windows = attention_output.reshape(
[-1, self.window_size, self.window_size, channels]
)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
)
# reverse cyclic shift
if self.shift_size > 0:
shift_value = (self.shift_size, self.shift_size)
if self.is_export:
shift_value = torch.tensor(shift_value, dtype=torch.int32)
attention_windows = torch.roll(
shifted_windows, shifts=shift_value, dims=(1, 2)
)
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.reshape(
[batch_size, height * width, channels]
)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_outputs = (
(layer_output, attention_outputs[1])
if output_attentions
else (layer_output,)
)
return layer_outputs
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
def __init__(
self, config, dim, input_resolution, depth, num_heads, drop_path, downsample
):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
DonutSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
self.is_export = config.is_export
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=nn.LayerNorm,
is_export=self.is_export,
)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
always_partition=False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(
hidden_states_before_downsampling, input_dimensions
)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (
hidden_states,
hidden_states_before_downsampling,
output_dimensions,
)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [
x.item()
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
]
self.layers = nn.ModuleList(
[
DonutSwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(
grid_size[0] // (2**i_layer),
grid_size[1] // (2**i_layer),
),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[
sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
],
downsample=(
DonutSwinPatchMerging
if (i_layer < self.num_layers - 1)
else None
),
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
output_hidden_states=False,
output_hidden_states_before_downsampling=False,
always_partition=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if output_hidden_states:
batch_size, _, hidden_size = hidden_states.shape
reshaped_hidden_state = hidden_states.view(
batch_size, *input_dimensions, hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
reshaped_hidden_state = hidden_states_before_downsampling.reshape(
[
batch_size,
*(output_dimensions[0], output_dimensions[1]),
hidden_size,
]
)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
reshaped_hidden_state = hidden_states.reshape(
[batch_size, *input_dimensions, hidden_size]
)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
if output_attentions:
all_self_attentions += layer_outputs[3:]
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions]
if v is not None
)
return DonutSwinEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
class DonutSwinPreTrainedModel(nn.Module):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DonutSwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def _initialize_weights(self, module):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
def post_init(self):
self.apply(self._initialize_weights)
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
return head_mask
class DonutSwinModel(DonutSwinPreTrainedModel):
def __init__(
self,
in_channels=3,
hidden_size=1024,
num_layers=4,
num_heads=[4, 8, 16, 32],
add_pooling_layer=True,
use_mask_token=False,
is_export=False,
):
super().__init__()
donut_swin_config = {
"return_dict": True,
"output_hidden_states": False,
"output_attentions": False,
"use_bfloat16": False,
"tf_legacy_loss": False,
"pruned_heads": {},
"tie_word_embeddings": True,
"chunk_size_feed_forward": 0,
"is_encoder_decoder": False,
"is_decoder": False,
"cross_attention_hidden_size": None,
"add_cross_attention": False,
"tie_encoder_decoder": False,
"max_length": 20,
"min_length": 0,
"do_sample": False,
"early_stopping": False,
"num_beams": 1,
"num_beam_groups": 1,
"diversity_penalty": 0.0,
"temperature": 1.0,
"top_k": 50,
"top_p": 1.0,
"typical_p": 1.0,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"no_repeat_ngram_size": 0,
"encoder_no_repeat_ngram_size": 0,
"bad_words_ids": None,
"num_return_sequences": 1,
"output_scores": False,
"return_dict_in_generate": False,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"remove_invalid_values": False,
"exponential_decay_length_penalty": None,
"suppress_tokens": None,
"begin_suppress_tokens": None,
"architectures": None,
"finetuning_task": None,
"id2label": {0: "LABEL_0", 1: "LABEL_1"},
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
"tokenizer_class": None,
"prefix": None,
"bos_token_id": None,
"pad_token_id": None,
"eos_token_id": None,
"sep_token_id": None,
"decoder_start_token_id": None,
"task_specific_params": None,
"problem_type": None,
"_name_or_path": "",
"_commit_hash": None,
"_attn_implementation_internal": None,
"transformers_version": None,
"hidden_size": hidden_size,
"num_layers": num_layers,
"path_norm": True,
"use_2d_embeddings": False,
"image_size": [420, 420],
"patch_size": 4,
"num_channels": in_channels,
"embed_dim": 128,
"depths": [2, 2, 14, 2],
"num_heads": num_heads,
"window_size": 5,
"mlp_ratio": 4.0,
"qkv_bias": True,
"hidden_dropout_prob": 0.0,
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.1,
"hidden_act": "gelu",
"use_absolute_embeddings": False,
"layer_norm_eps": 1e-05,
"initializer_range": 0.02,
"is_export": is_export,
}
config = DonutSwinConfig(**donut_swin_config)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
self.out_channels = hidden_size
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
input_data=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple, DonutSwinModelOutput]:
r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos
)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = None
if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.permute(0, 2, 1))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
output = (sequence_output, pooled_output) + encoder_outputs[1:]
return output
donut_swin_output = DonutSwinModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
if self.training:
return donut_swin_output, label, attention_mask
else:
return donut_swin_output
import torch
import torch.nn.functional as F
from torch import nn
class ConvBNAct(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
):
super().__init__()
self.use_act = use_act
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(out_channels)
if self.use_act:
self.act = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
return x
class ESEModule(nn.Module):
def __init__(self, channels):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0,
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv(x)
x = self.sigmoid(x)
return x * identity
class HG_Block(nn.Module):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False,
):
super().__init__()
self.identity = identity
self.layers = nn.ModuleList()
self.layers.append(
ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1,
)
)
for _ in range(layer_num - 1):
self.layers.append(
ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1,
)
)
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
)
self.att = ESEModule(out_channels)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
x = self.aggregation_conv(x)
x = self.att(x)
if self.identity:
x += identity
return x
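def _demo_hg_block():
    # Illustrative sketch (channel/spatial sizes assumed, not prescribed by this
    # file): HG_Block concatenates its input with every intermediate feature, so
    # the 1x1 aggregation conv sees in_channels + layer_num * mid_channels inputs.
    import torch
    block = HG_Block(in_channels=16, mid_channels=8, out_channels=32, layer_num=4).eval()
    x = torch.randn(1, 16, 24, 24)
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # torch.Size([1, 32, 24, 24]); the cat had 16 + 4 * 8 = 48 channels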
class HG_Stage(nn.Module):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
downsample=True,
stride=[2, 1],
):
super().__init__()
self.downsample = downsample
if downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
groups=in_channels,
use_act=False,
)
blocks_list = []
blocks_list.append(
HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
)
for _ in range(block_num - 1):
blocks_list.append(
HG_Block(
out_channels, mid_channels, out_channels, layer_num, identity=True
)
)
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class PPHGNet(nn.Module):
"""
PPHGNet
Args:
        stem_channels: list. Stem channel list of PPHGNet.
        stage_config: dict. The configuration of each stage of PPHGNet, such as the number of channels, stride, etc.
        layer_num: int. Number of layers of HG_Block.
        in_channels: int=3. Number of input channels.
        det: bool=False. Whether to build the detection variant, which returns multi-scale feature maps.
        out_indices: list. Indices of the stages whose outputs are collected in detection mode. Defaults to [0, 1, 2, 3].
    Returns:
        model: nn.Module. Specific PPHGNet model depends on args.
"""
def __init__(
self,
stem_channels,
stage_config,
layer_num,
in_channels=3,
det=False,
out_indices=None,
):
super().__init__()
self.det = det
self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
# stem
stem_channels.insert(0, in_channels)
self.stem = nn.Sequential(
*[
ConvBNAct(
in_channels=stem_channels[i],
out_channels=stem_channels[i + 1],
kernel_size=3,
stride=2 if i == 0 else 1,
)
for i in range(len(stem_channels) - 1)
]
)
if self.det:
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# stages
self.stages = nn.ModuleList()
self.out_channels = []
for block_id, k in enumerate(stage_config):
(
in_channels,
mid_channels,
out_channels,
block_num,
downsample,
stride,
) = stage_config[k]
self.stages.append(
HG_Stage(
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
downsample,
stride,
)
)
if block_id in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
if self.det:
x = self.pool(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
"""
PPHGNet_small
Args:
        pretrained: bool=False or str. If `True`, load pretrained parameters;
            if str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the SSLD-distilled pretrained model when pretrained is True.
    Returns:
        model: nn.Module. Specific `PPHGNet_small` model depends on args.
"""
stage_config_det = {
        # in_channels, mid_channels, out_channels, blocks, downsample, stride
"stage1": [128, 128, 256, 1, False, 2],
"stage2": [256, 160, 512, 1, True, 2],
"stage3": [512, 192, 768, 2, True, 2],
"stage4": [768, 224, 1024, 1, True, 2],
}
stage_config_rec = {
        # in_channels, mid_channels, out_channels, blocks, downsample, stride
"stage1": [128, 128, 256, 1, True, [2, 1]],
"stage2": [256, 160, 512, 1, True, [1, 2]],
"stage3": [512, 192, 768, 2, True, [2, 1]],
"stage4": [768, 224, 1024, 1, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[64, 64, 128],
stage_config=stage_config_det if det else stage_config_rec,
layer_num=6,
det=det,
**kwargs
)
return model
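def _demo_pphgnet_small():
    # Illustrative sketch: build the recognition variant and run a dummy text
    # crop through it. The 48x320 input size is an assumption for illustration;
    # the rec strides reduce it to 3x80 before the final [3, 2] average pool.
    import torch
    model = PPHGNet_small(det=False).eval()
    x = torch.randn(1, 3, 48, 320)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 1024, 1, 40])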
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import torch
import torch.nn.functional as F
from torch import nn
from ..common import Activation
NET_CONFIG_det = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5": [
[3, 128, 256, 2, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
],
"blocks6": [
[5, 256, 512, 2, True],
[5, 512, 512, 1, True],
[5, 512, 512, 1, False],
[5, 512, 512, 1, False],
],
}
NET_CONFIG_rec = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
"blocks5": [
[3, 128, 256, (1, 2), False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
[5, 256, 256, 1, False],
],
"blocks6": [
[5, 256, 512, (2, 1), True],
[5, 512, 512, 1, True],
[5, 512, 512, (2, 1), False],
[5, 512, 512, 1, False],
],
}
def make_divisible(v, divisor=16, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
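def _demo_make_divisible():
    # Hedged sanity checks: make_divisible rounds to the nearest multiple of
    # `divisor` (default 16), clamps below at `min_value`, and bumps the result
    # up one step whenever it would fall under 90% of the original value.
    assert make_divisible(16 * 0.5) == 16   # 8 is clamped up to min_value=16
    assert make_divisible(512 * 0.5) == 256
    assert make_divisible(100) == 96        # 96 >= 0.9 * 100, so no bump-up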
class LearnableAffineBlock(nn.Module):
def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
super().__init__()
self.scale = nn.Parameter(torch.Tensor([scale_value]))
self.bias = nn.Parameter(torch.Tensor([bias_value]))
def forward(self, x):
return self.scale * x + self.bias
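def _demo_learnable_affine_block():
    # Minimal sketch: LAB is a learnable y = scale * x + bias, initialised to
    # the identity (scale=1, bias=0), so it starts out as a no-op.
    import torch
    lab = LearnableAffineBlock()
    x = torch.randn(4)
    print(torch.allclose(lab(x), x))  # True at initialisation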
class ConvBNLayer(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(
out_channels,
)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Act(nn.Module):
def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
super().__init__()
if act == "hswish":
self.act = nn.Hardswish(inplace=True)
else:
assert act == "relu"
self.act = Activation(act)
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
return self.lab(self.act(x))
class LearnableRepLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
num_conv_branches=1,
lr_mult=1.0,
lab_lr=0.1,
):
super().__init__()
self.is_repped = False
self.groups = groups
self.stride = stride
self.kernel_size = kernel_size
self.in_channels = in_channels
self.out_channels = out_channels
self.num_conv_branches = num_conv_branches
self.padding = (kernel_size - 1) // 2
self.identity = (
nn.BatchNorm2d(
num_features=in_channels,
)
if out_channels == in_channels and stride == 1
else None
)
self.conv_kxk = nn.ModuleList(
[
ConvBNLayer(
in_channels,
out_channels,
kernel_size,
stride,
groups=groups,
lr_mult=lr_mult,
)
for _ in range(self.num_conv_branches)
]
)
self.conv_1x1 = (
ConvBNLayer(
in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
)
if kernel_size > 1
else None
)
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
# for export
if self.is_repped:
out = self.lab(self.reparam_conv(x))
if self.stride != 2:
out = self.act(out)
return out
out = 0
if self.identity is not None:
out += self.identity(x)
if self.conv_1x1 is not None:
out += self.conv_1x1(x)
for conv in self.conv_kxk:
out += conv(x)
out = self.lab(out)
if self.stride != 2:
out = self.act(out)
return out
def rep(self):
if self.is_repped:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
self.is_repped = True
def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
if not isinstance(kernel1x1, torch.Tensor):
return 0
else:
return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])
def _get_kernel_bias(self):
kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(
kernel_conv_1x1, self.kernel_size // 2
)
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
kernel_conv_kxk = 0
bias_conv_kxk = 0
for conv in self.conv_kxk:
kernel, bias = self._fuse_bn_tensor(conv)
kernel_conv_kxk += kernel
bias_conv_kxk += bias
kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
return kernel_reparam, bias_reparam
def _fuse_bn_tensor(self, branch):
if not branch:
return 0, 0
elif isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
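def _demo_learnable_rep_layer():
    # Minimal sketch (channel/spatial sizes assumed): after rep(), the single
    # fused conv should reproduce the multi-branch eval-mode output.
    import torch
    layer = LearnableRepLayer(in_channels=8, out_channels=8, kernel_size=3, stride=1).eval()
    x = torch.randn(2, 8, 16, 16)
    with torch.no_grad():
        y_multi = layer(x)
        layer.rep()
        y_fused = layer(x)
    print(torch.allclose(y_multi, y_fused, atol=1e-5))  # expected: True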
class SELayer(nn.Module):
def __init__(self, channel, reduction=4, lr_mult=1.0):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
)
self.hardsigmoid = nn.Hardsigmoid(inplace=True)
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = identity * x
return x
class LCNetV3Block(nn.Module):
def __init__(
self,
in_channels,
out_channels,
stride,
dw_size,
use_se=False,
conv_kxk_num=4,
lr_mult=1.0,
lab_lr=0.1,
):
super().__init__()
self.use_se = use_se
self.dw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=dw_size,
stride=stride,
groups=in_channels,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr,
)
if use_se:
self.se = SELayer(in_channels, lr_mult=lr_mult)
self.pw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr,
)
def forward(self, x):
x = self.dw_conv(x)
if self.use_se:
x = self.se(x)
x = self.pw_conv(x)
return x
class PPLCNetV3(nn.Module):
def __init__(
self,
scale=1.0,
conv_kxk_num=4,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
lab_lr=0.1,
det=False,
**kwargs
):
super().__init__()
self.scale = scale
self.lr_mult_list = lr_mult_list
self.det = det
self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
assert isinstance(
self.lr_mult_list, (list, tuple)
), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list)
)
assert (
len(self.lr_mult_list) == 6
), "lr_mult_list length should be 6 but got {}".format(len(self.lr_mult_list))
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=make_divisible(16 * scale),
kernel_size=3,
stride=2,
lr_mult=self.lr_mult_list[0],
)
self.blocks2 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[1],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks2"])
]
)
self.blocks3 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[2],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks3"])
]
)
self.blocks4 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[3],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks4"])
]
)
self.blocks5 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[4],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks5"])
]
)
self.blocks6 = nn.Sequential(
*[
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[5],
lab_lr=lab_lr,
)
for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks6"])
]
)
self.out_channels = make_divisible(512 * scale)
if self.det:
mv_c = [16, 24, 56, 480]
self.out_channels = [
make_divisible(self.net_config["blocks3"][-1][2] * scale),
make_divisible(self.net_config["blocks4"][-1][2] * scale),
make_divisible(self.net_config["blocks5"][-1][2] * scale),
make_divisible(self.net_config["blocks6"][-1][2] * scale),
]
self.layer_list = nn.ModuleList(
[
nn.Conv2d(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
nn.Conv2d(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
]
)
self.out_channels = [
int(mv_c[0] * scale),
int(mv_c[1] * scale),
int(mv_c[2] * scale),
int(mv_c[3] * scale),
]
def forward(self, x):
out_list = []
x = self.conv1(x)
x = self.blocks2(x)
x = self.blocks3(x)
out_list.append(x)
x = self.blocks4(x)
out_list.append(x)
x = self.blocks5(x)
out_list.append(x)
x = self.blocks6(x)
out_list.append(x)
if self.det:
out_list[0] = self.layer_list[0](out_list[0])
out_list[1] = self.layer_list[1](out_list[1])
out_list[2] = self.layer_list[2](out_list[2])
out_list[3] = self.layer_list[3](out_list[3])
return out_list
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
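def _demo_pplcnetv3_rec():
    # Illustrative sketch: the recognition variant shrinks height aggressively
    # while keeping most of the width resolution. The 48x320 input is an
    # assumed size for illustration, not something this file prescribes.
    import torch
    model = PPLCNetV3(scale=0.95).eval()
    x = torch.randn(1, 3, 48, 320)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 480, 1, 40]) with scale=0.95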
from torch import nn
from .det_mobilenet_v3 import ConvBNLayer, ResidualUnit, make_divisible
class MobileNetV3(nn.Module):
def __init__(
self,
in_channels=3,
model_name="small",
scale=0.5,
large_stride=None,
small_stride=None,
**kwargs
):
super(MobileNetV3, self).__init__()
if small_stride is None:
small_stride = [2, 2, 2, 2]
if large_stride is None:
large_stride = [1, 2, 2, 2]
assert isinstance(
large_stride, list
), "large_stride type must " "be list but got {}".format(type(large_stride))
assert isinstance(
small_stride, list
), "small_stride type must " "be list but got {}".format(type(small_stride))
assert (
len(large_stride) == 4
), "large_stride length must be " "4 but got {}".format(len(large_stride))
assert (
len(small_stride) == 4
), "small_stride length must be " "4 but got {}".format(len(small_stride))
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", large_stride[0]],
[3, 64, 24, False, "relu", (large_stride[1], 1)],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", (large_stride[2], 1)],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hard_swish", 1],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1],
[5, 672, 160, True, "hard_swish", (large_stride[3], 1)],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", (small_stride[0], 1)],
[3, 72, 24, False, "relu", (small_stride[1], 1)],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hard_swish", (small_stride[2], 1)],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1],
[5, 288, 96, True, "hard_swish", (small_stride[3], 1)],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1],
]
cls_ch_squeeze = 576
else:
            raise NotImplementedError(
                "model_name [" + model_name + "] is not supported!"
            )
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert (
scale in supported_scale
), "supported scales are {} but input scale is {}".format(
supported_scale, scale
)
inplanes = 16
# conv1
self.conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act="hard_swish",
name="conv1",
)
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for k, exp, c, se, nl, s in cfg:
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2),
)
)
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act="hard_swish",
name="conv_last",
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
def forward(self, x):
x = self.conv1(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.pool(x)
return x
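def _demo_mobilenetv3_rec():
    # Hedged sketch, assuming the ConvBNLayer/ResidualUnit imported from
    # det_mobilenet_v3 behave like their Paddle counterparts: the "small"
    # recognition backbone halves height at each strided stage, then the final
    # 2x2 max pool brings a 32-high crop down to height 1.
    import torch
    model = MobileNetV3(model_name="small", scale=0.5).eval()
    x = torch.randn(1, 3, 32, 320)
    with torch.no_grad():
        y = model(x)
    print(y.shape, model.out_channels)  # expect height 1 after the final pool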
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='hard_swish'):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias=False)
self._batch_norm = nn.BatchNorm2d(
num_filters,
)
if self.act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DepthwiseSeparable(nn.Module):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
dw_size=3,
padding=1,
use_se=False):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale))
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Module):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=in_channels,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
def hardsigmoid(x):
return F.relu6(x + 3., inplace=True) / 6.
class SEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
bias=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
x = torch.mul(inputs, outputs)
return x
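def _demo_mobilenetv1_enhance():
    # Minimal sketch (input size assumed): the (2, 1) strides shrink height
    # only, which suits wide, short text-recognition crops.
    import torch
    model = MobileNetV1Enhance(scale=0.5).eval()
    x = torch.randn(1, 3, 32, 320)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 512, 1, 80])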
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .rec_donut_swin import DonutSwinModelOutput
from typing import List, Dict, Union, Callable
class IdentityBasedConv1x1(nn.Conv2d):
def __init__(self, channels, groups=1):
super(IdentityBasedConv1x1, self).__init__(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
            bias=False,
)
assert channels % groups == 0
input_dim = channels // groups
id_value = np.zeros((channels, input_dim, 1, 1))
for i in range(channels):
id_value[i, i % input_dim, 0, 0] = 1
        # register as a non-persistent buffer so it follows the module across
        # devices/dtypes without polluting checkpoints
        self.register_buffer("id_tensor", torch.from_numpy(id_value).float(), persistent=False)
        nn.init.zeros_(self.weight)
def forward(self, input):
kernel = self.weight + self.id_tensor
result = F.conv2d(
input,
kernel,
None,
stride=1,
padding=0,
            dilation=self.dilation,
            groups=self.groups,
)
return result
def get_actual_kernel(self):
return self.weight + self.id_tensor
class BNAndPad(nn.Module):
def __init__(
self,
pad_pixels,
num_features,
epsilon=1e-5,
momentum=0.1,
last_conv_bias=None,
bn=nn.BatchNorm2d,
):
super().__init__()
        self.bn = bn(num_features, momentum=momentum, eps=epsilon)
self.pad_pixels = pad_pixels
self.last_conv_bias = last_conv_bias
def forward(self, input):
output = self.bn(input)
if self.pad_pixels > 0:
            bias = -self.bn.running_mean
            if self.last_conv_bias is not None:
                bias += self.last_conv_bias
            pad_values = self.bn.bias + self.bn.weight * (
                bias / torch.sqrt(self.bn.running_var + self.bn.eps)
            )
""" pad """
# TODO: n,h,w,c format is not supported yet
n, c, h, w = output.shape
values = pad_values.reshape([1, -1, 1, 1])
w_values = values.expand([n, -1, self.pad_pixels, w])
x = torch.cat([w_values, output, w_values], dim=2)
h = h + self.pad_pixels * 2
h_values = values.expand([n, -1, h, self.pad_pixels])
x = torch.cat([h_values, x, h_values], dim=3)
output = x
return output
@property
def weight(self):
return self.bn.weight
@property
def bias(self):
return self.bn.bias
    @property
    def running_mean(self):
        return self.bn.running_mean
    @property
    def running_var(self):
        return self.bn.running_var
    @property
    def eps(self):
        return self.bn.eps
def conv_bn(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
):
conv_layer = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
        bias=False,
padding_mode=padding_mode,
)
    bn_layer = nn.BatchNorm2d(num_features=out_channels)
    se = nn.Sequential()
    se.add_module("conv", conv_layer)
    se.add_module("bn", bn_layer)
return se
def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    return (
        kernel * ((gamma / std).reshape([-1, 1, 1, 1])),
        bn.bias - bn.running_mean * gamma / std,
    )
def transII_addbranch(kernels, biases):
return sum(kernels), sum(biases)
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
if groups == 1:
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
b_hat = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
else:
k_slices = []
b_slices = []
        k1_T = k1.permute(1, 0, 2, 3)
k1_group_width = k1.shape[0] // groups
k2_group_width = k2.shape[0] // groups
for g in range(groups):
k1_T_slice = k1_T[:, g * k1_group_width : (g + 1) * k1_group_width, :, :]
k2_slice = k2[g * k2_group_width : (g + 1) * k2_group_width, :, :, :]
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
b_slices.append(
(
k2_slice
* b1[g * k1_group_width : (g + 1) * k1_group_width].reshape(
[1, -1, 1, 1]
)
).sum((1, 2, 3))
)
k, b_hat = transIV_depthconcat(k_slices, b_slices)
return k, b_hat + b2
def transIV_depthconcat(kernels, biases):
return torch.cat(kernels, dim=0), torch.cat(biases)
def transV_avg(channels, kernel_size, groups):
input_dim = channels // groups
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = (
1.0 / kernel_size**2
)
return k
def transVI_multiscale(kernel, target_kernel_size):
H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
return F.pad(
kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]
)
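def _demo_transI_fusebn():
    # Hedged sketch: folding a BatchNorm into the preceding conv kernel should
    # leave eval-mode outputs unchanged (up to floating-point error).
    import torch
    import torch.nn.functional as F
    from torch import nn
    conv = nn.Conv2d(4, 8, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-0.1, 0.1)  # perturb the stats so the check is non-trivial
    bn.running_var.uniform_(0.9, 1.1)
    x = torch.randn(2, 4, 10, 10)
    k, b = transI_fusebn(conv.weight, bn)
    with torch.no_grad():
        print(torch.allclose(bn(conv(x)), F.conv2d(x, k, b, padding=1), atol=1e-5))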
class DiverseBranchBlock(nn.Module):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
is_repped=False,
single_init=False,
**kwargs,
):
super().__init__()
padding = (filter_size - 1) // 2
dilation = 1
in_channels = num_channels
out_channels = num_filters
kernel_size = filter_size
internal_channels_1x1_3x3 = None
nonlinear = act
self.is_repped = is_repped
if nonlinear is None:
self.nonlinear = nn.Identity()
else:
self.nonlinear = nn.ReLU()
self.kernel_size = kernel_size
self.out_channels = out_channels
self.groups = groups
assert padding == kernel_size // 2
if is_repped:
self.dbb_reparam = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
)
else:
self.dbb_origin = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.dbb_avg = nn.Sequential()
if groups < out_channels:
                self.dbb_avg.add_module(
"conv",
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias=False,
),
)
                self.dbb_avg.add_module(
                    "bn", BNAndPad(pad_pixels=padding, num_features=out_channels)
                )
                self.dbb_avg.add_module(
                    "avg",
                    nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0),
                )
self.dbb_1x1 = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
groups=groups,
)
            else:
                self.dbb_avg.add_module(
                    "avg",
                    nn.AvgPool2d(
                        kernel_size=kernel_size, stride=stride, padding=padding
                    ),
                )
            # The trailing BN belongs to the average-pool branch in both cases;
            # get_equivalent_kernel_bias() reads dbb_avg.avgbn unconditionally.
            self.dbb_avg.add_module("avgbn", nn.BatchNorm2d(out_channels))
if internal_channels_1x1_3x3 is None:
internal_channels_1x1_3x3 = (
in_channels if groups < out_channels else 2 * in_channels
) # For mobilenet, it is better to have 2X internal channels
self.dbb_1x1_kxk = nn.Sequential()
if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module(
"idconv1", IdentityBasedConv1x1(channels=in_channels, groups=groups)
)
else:
                self.dbb_1x1_kxk.add_module(
"conv1",
nn.Conv2d(
in_channels=in_channels,
out_channels=internal_channels_1x1_3x3,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias=False,
),
)
            self.dbb_1x1_kxk.add_module(
"bn1",
BNAndPad(pad_pixels=padding, num_features=internal_channels_1x1_3x3),
)
            self.dbb_1x1_kxk.add_module(
"conv2",
nn.Conv2d(
in_channels=internal_channels_1x1_3x3,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
groups=groups,
bias=False,
),
)
self.dbb_1x1_kxk.add_sublayer("bn2", nn.BatchNorm2D(out_channels))
# The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
if single_init:
# Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
self.single_init()
def forward(self, inputs):
if self.is_repped:
return self.nonlinear(self.dbb_reparam(inputs))
out = self.dbb_origin(inputs)
if hasattr(self, "dbb_1x1"):
out += self.dbb_1x1(inputs)
out += self.dbb_avg(inputs)
out += self.dbb_1x1_kxk(inputs)
return self.nonlinear(out)
def init_gamma(self, gamma_value):
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
if hasattr(self, "dbb_1x1"):
torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
if hasattr(self, "dbb_avg"):
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
if hasattr(self, "dbb_1x1_kxk"):
torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
def single_init(self):
self.init_gamma(0.0)
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
def get_equivalent_kernel_bias(self):
k_origin, b_origin = transI_fusebn(
self.dbb_origin.conv.weight, self.dbb_origin.bn
)
if hasattr(self, "dbb_1x1"):
k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
else:
k_1x1, b_1x1 = 0, 0
if hasattr(self.dbb_1x1_kxk, "idconv1"):
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
else:
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(
k_1x1_kxk_first, self.dbb_1x1_kxk.bn1
)
k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2
)
k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
k_1x1_kxk_first,
b_1x1_kxk_first,
k_1x1_kxk_second,
b_1x1_kxk_second,
groups=self.groups,
)
k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg, self.dbb_avg.avgbn)
if hasattr(self.dbb_avg, "conv"):
k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
self.dbb_avg.conv.weight, self.dbb_avg.bn
)
k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
k_1x1_avg_first,
b_1x1_avg_first,
k_1x1_avg_second,
b_1x1_avg_second,
groups=self.groups,
)
else:
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
return transII_addbranch(
(k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
(b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged),
)
def re_parameterize(self):
if self.is_repped:
return
kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(
            in_channels=self.dbb_origin.conv.in_channels,
            out_channels=self.dbb_origin.conv.out_channels,
            kernel_size=self.dbb_origin.conv.kernel_size,
            stride=self.dbb_origin.conv.stride,
            padding=self.dbb_origin.conv.padding,
            dilation=self.dbb_origin.conv.dilation,
            groups=self.dbb_origin.conv.groups,
            bias=True,
        )
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
self.__delattr__("dbb_origin")
self.__delattr__("dbb_avg")
if hasattr(self, "dbb_1x1"):
self.__delattr__("dbb_1x1")
self.__delattr__("dbb_1x1_kxk")
self.is_repped = True
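def _demo_dbb_reparam():
    # Minimal sketch (sizes assumed): re_parameterize() should fuse the origin,
    # 1x1, 1x1-kxk and avg branches into one conv with identical eval behaviour.
    import torch
    block = DiverseBranchBlock(num_channels=8, num_filters=8, filter_size=3).eval()
    x = torch.randn(2, 8, 16, 16)
    with torch.no_grad():
        y_multi = block(x)
        block.re_parameterize()
        y_fused = block(x)
    print(torch.allclose(y_multi, y_fused, atol=1e-4))  # expected: True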
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, inputs):
return inputs
class TheseusLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.res_dict = {}
# self.res_name = self.full_name()
self.res_name = self.__class__.__name__.lower()
self.pruner = None
self.quanter = None
self.init_net(*args, **kwargs)
def _return_dict_hook(self, layer, input, output):
res_dict = {"logits": output}
# 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict):
# clear the res_dict because the forward process may change according to input
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
def init_net(
self,
stages_pattern=None,
return_patterns=None,
return_stages=None,
freeze_befor=None,
stop_after=None,
*args,
**kwargs,
):
# init the output of net
if return_patterns or return_stages:
if return_patterns and return_stages:
msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
return_stages = [
val
for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
# call update_res function after the __init__ of the object has completed execution, that is, the constructing of layer or model has been completed.
def update_res_hook(layer, input):
self.update_res(return_patterns)
self.register_forward_pre_hook(update_res_hook)
# freeze subnet
if freeze_befor is not None:
self.freeze_befor(freeze_befor)
# set subnet to Identity
if stop_after is not None:
self.stop_after(stop_after)
def init_res(self, stages_pattern, return_patterns=None, return_stages=None):
if return_patterns and return_stages:
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
return_stages = [
val
for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
self.update_res(return_patterns)
def replace_sub(self, *args, **kwargs) -> None:
msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
raise DeprecationWarning(msg)
def upgrade_sublayer(
self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Module, str], nn.Module],
    ) -> List[str]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Module, str], nn.Module]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Module) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
        Returns:
            List[str]: The patterns in 'layer_name_pattern' that were matched and replaced.
        Examples:
            from torch import nn
            def rep_func(layer: nn.Module, pattern: str):
                new_layer = nn.Conv2d(
                    in_channels=layer.in_channels,
                    out_channels=layer.out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer
            # 'net' is any TheseusLayer-based model containing the referenced sublayers
            res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # ['blocks[11].depthwise_conv.conv', 'blocks[12].depthwise_conv.conv']
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
hit_layer_pattern_list = []
for pattern in layer_name_pattern:
# parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list:
continue
sub_layer_parent = layer_list[-2]["layer"] if len(layer_list) > 1 else self
sub_layer = layer_list[-1]["layer"]
sub_layer_name = layer_list[-1]["name"]
sub_layer_index_list = layer_list[-1]["index_list"]
new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index_list:
if len(sub_layer_index_list) > 1:
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]
]
for sub_layer_index in sub_layer_index_list[1:-1]:
sub_layer_parent = sub_layer_parent[sub_layer_index]
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
else:
getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]
] = new_sub_layer
else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
hit_layer_pattern_list.append(pattern)
return hit_layer_pattern_list
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
Args:
stop_layer_name (str): The name of layer that stop forward and backward after this layer.
Returns:
bool: 'True' if successful, 'False' otherwise.
"""
layer_list = parse_pattern_str(stop_layer_name, self)
if not layer_list:
return False
parent_layer = self
for layer_dict in layer_list:
name, index_list = layer_dict["name"], layer_dict["index_list"]
if not set_identity(parent_layer, name, index_list):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
return False
parent_layer = layer_dict["layer"]
return True
def freeze_befor(self, layer_name: str) -> bool:
"""freeze the layer named layer_name and its previous layer.
Args:
layer_name (str): The name of layer that would be freezed.
Returns:
bool: 'True' if successful, 'False' otherwise.
"""
def stop_grad(layer, pattern):
class StopGradLayer(nn.Module):
def __init__(self):
super().__init__()
self.layer = layer
                def forward(self, x):
                    x = self.layer(x)
                    # detach() cuts the autograd graph, freezing everything before this layer
                    return x.detach()
new_layer = StopGradLayer()
return new_layer
res = self.upgrade_sublayer(layer_name, stop_grad)
if len(res) == 0:
msg = "Failed to stop the gradient before the layer named '{layer_name}'"
return False
return True
    def update_res(self, return_patterns: Union[str, List[str]]) -> List[str]:
"""update the result(s) to be returned.
Args:
return_patterns (Union[str, List[str]]): The name of layer to return output.
        Returns:
            List[str]: The patterns that were matched and hooked successfully.
"""
# clear res_dict that could have been set
self.res_dict = {}
class Handler(object):
def __init__(self, res_dict):
# res_dict is a reference
self.res_dict = res_dict
def __call__(self, layer, pattern):
layer.res_dict = self.res_dict
layer.res_name = pattern
if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove()
                layer.hook_remove_helper = layer.register_forward_hook(
save_sub_res_hook
)
return layer
handle_func = Handler(self.res_dict)
hit_layer_pattern_list = self.upgrade_sublayer(
return_patterns, handle_func=handle_func
)
if hasattr(self, "hook_remove_helper"):
self.hook_remove_helper.remove()
        self.hook_remove_helper = self.register_forward_hook(
self._return_dict_hook
)
return hit_layer_pattern_list
def save_sub_res_hook(layer, input, output):
layer.res_dict[layer.res_name] = output
def set_identity(
parent_layer: nn.Module, layer_name: str, layer_index_list: str = None
) -> bool:
"""set the layer specified by layer_name and layer_index_list to Identity.
Args:
parent_layer (nn.Module): The parent layer of target layer specified by layer_name and layer_index_list.
layer_name (str): The name of target layer to be set to Identity.
layer_index_list (str, optional): The index of target layer to be set to Identity in parent_layer. Defaults to None.
Returns:
bool: True if successfully, False otherwise.
"""
stop_after = False
    for sub_layer_name in parent_layer._modules:
        if stop_after:
            parent_layer._modules[sub_layer_name] = Identity()
            continue
        if sub_layer_name == layer_name:
            stop_after = True
    if layer_index_list and stop_after:
        layer_container = parent_layer._modules[layer_name]
        for num, layer_index in enumerate(layer_index_list):
            stop_after = False
            for i in range(num):
                layer_container = layer_container[int(layer_index_list[i])]
            for sub_layer_index in layer_container._modules:
                if stop_after:
                    layer_container._modules[sub_layer_index] = Identity()
                    continue
                if layer_index == sub_layer_index:
                    stop_after = True
return stop_after
def parse_pattern_str(
pattern: str, parent_layer: nn.Module
) -> Union[None, List[Dict[str, Union[nn.Module, str, None]]]]:
"""parse the string type pattern.
Args:
pattern (str): The pattern to describe layer.
parent_layer (nn.Module): The root layer relative to the pattern.
Returns:
Union[None, List[Dict[str, Union[nn.Module, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
[
{"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
{"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
...
]
"""
pattern_list = pattern.split(".")
if not pattern_list:
msg = f"The pattern('{pattern}') is illegal. Please check and retry."
return None
layer_list = []
while len(pattern_list) > 0:
if "[" in pattern_list[0]:
target_layer_name = pattern_list[0].split("[")[0]
target_layer_index_list = list(
index.split("]")[0] for index in pattern_list[0].split("[")[1:]
)
else:
target_layer_name = pattern_list[0]
target_layer_index_list = None
target_layer = getattr(parent_layer, target_layer_name, None)
if target_layer is None:
msg = f"Not found layer named('{target_layer_name}') specified in pattern('{pattern}')."
return None
        if target_layer_index_list:
            for target_layer_index in target_layer_index_list:
                if int(target_layer_index) < 0 or int(target_layer_index) >= len(
                    target_layer
                ):
                    msg = f"Not found layer by index('{target_layer_index}') specified in pattern('{pattern}'). The index should be >= 0 and < {len(target_layer)}."
                    return None
                target_layer = target_layer[int(target_layer_index)]
layer_list.append(
{
"layer": target_layer,
"name": target_layer_name,
"index_list": target_layer_index_list,
}
)
pattern_list = pattern_list[1:]
parent_layer = target_layer
return layer_list
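def _demo_parse_pattern_str():
    # Hedged sketch: resolve a dotted pattern against a small module tree and
    # inspect the hops it records. Indexed patterns such as "blocks[0]" are
    # also supported; each hop carries the layer, its name, and any indices.
    from torch import nn

    class _TinyNet(TheseusLayer):
        def __init__(self):
            super().__init__()
            self.blocks = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

    net = _TinyNet()
    hops = parse_pattern_str("blocks", net)
    print([(hop["name"], hop["index_list"]) for hop in hops])  # [('blocks', None)]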
class LearnableAffineBlock(TheseusLayer):
"""
Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
Args:
scale_value (float): The initial value of the scale parameter, default is 1.0.
bias_value (float): The initial value of the bias parameter, default is 0.0.
lr_mult (float): The learning rate multiplier, default is 1.0.
lab_lr (float): The learning rate, default is 0.01.
"""
    def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
        super().__init__()
        # lr_mult and lab_lr are kept for API compatibility with the Paddle
        # original, where they set per-parameter learning rates; with torch,
        # configure parameter groups in the optimizer instead.
        self.scale = nn.Parameter(torch.full([1], float(scale_value)))
        self.bias = nn.Parameter(torch.full([1], float(bias_value)))
def forward(self, x):
return self.scale * x + self.bias
class ConvBNAct(TheseusLayer):
"""
ConvBNAct is a combination of convolution and batchnorm layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
stride (int): Stride of the convolution. Defaults to 1.
padding (int/str): Padding or padding type for the convolution. Defaults to 1.
groups (int): Number of groups for the convolution. Defaults to 1.
        use_act (bool): Whether to use activation function. Defaults to True.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
groups=1,
use_act=True,
use_lab=False,
lr_mult=1.0,
):
super().__init__()
self.use_act = use_act
self.use_lab = use_lab
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(
out_channels,
)
if self.use_act:
self.act = nn.ReLU()
if self.use_lab:
self.lab = LearnableAffineBlock(lr_mult=lr_mult)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
if self.use_lab:
x = self.lab(x)
return x
class LightConvBNAct(TheseusLayer):
"""
LightConvBNAct is a combination of pw and dw layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the depth-wise convolution kernel.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
use_lab=False,
lr_mult=1.0,
**kwargs,
):
super().__init__()
self.conv1 = ConvBNAct(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
use_act=False,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.conv2 = ConvBNAct(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=out_channels,
use_act=True,
use_lab=use_lab,
lr_mult=lr_mult,
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class PaddingSameAsPaddleMaxPool2d(torch.nn.Module):
def __init__(self, kernel_size, stride=1):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.pool = torch.nn.MaxPool2d(kernel_size, stride, padding=0, ceil_mode=True)
def forward(self, x):
_, _, h, w = x.shape
pad_h_total = max(0, (math.ceil(h / self.stride) - 1) * self.stride + self.kernel_size - h)
pad_w_total = max(0, (math.ceil(w / self.stride) - 1) * self.stride + self.kernel_size - w)
pad_h = pad_h_total // 2
pad_w = pad_w_total // 2
x = torch.nn.functional.pad(x, [pad_w, pad_w_total - pad_w, pad_h, pad_h_total - pad_h])
return self.pool(x)
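def _demo_same_pad_pool():
    # Minimal sketch: with kernel_size=2 and stride=1 the pool preserves the
    # spatial size by padding asymmetrically, mimicking Paddle's SAME padding.
    import torch
    pool = PaddingSameAsPaddleMaxPool2d(kernel_size=2, stride=1)
    x = torch.randn(1, 8, 15, 20)
    print(pool(x).shape)  # torch.Size([1, 8, 15, 20])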
class StemBlock(TheseusLayer):
"""
StemBlock for PP-HGNetV2.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
        use_lab (bool): Whether to use the LAB operation. Defaults to False.
        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
        text_rec (bool): Whether the stem is used for text recognition; keeps stride 1 in stem3. Defaults to False.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
use_lab=False,
lr_mult=1.0,
text_rec=False,
):
super().__init__()
self.stem1 = ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=2,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem2a = ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels // 2,
kernel_size=2,
stride=1,
padding="same",
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem2b = ConvBNAct(
in_channels=mid_channels // 2,
out_channels=mid_channels,
kernel_size=2,
stride=1,
padding="same",
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem3 = ConvBNAct(
in_channels=mid_channels * 2,
out_channels=mid_channels,
kernel_size=3,
stride=1 if text_rec else 2,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.stem4 = ConvBNAct(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.pool = PaddingSameAsPaddleMaxPool2d(
kernel_size=2, stride=1,
)
def forward(self, x):
x = self.stem1(x)
x2 = self.stem2a(x)
x2 = self.stem2b(x2)
x1 = self.pool(x)
x = torch.cat([x1, x2], 1)
x = self.stem3(x)
x = self.stem4(x)
return x
class HGV2_Block(TheseusLayer):
"""
HGV2_Block, the basic unit that constitutes the HGV2_Stage.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
        identity (bool): Whether to add the input to the output (residual connection). Defaults to False.
        light_block (bool): Whether to use LightConvBNAct instead of ConvBNAct. Defaults to True.
        use_lab (bool): Whether to use the LAB operation. Defaults to False.
        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
kernel_size=3,
layer_num=6,
identity=False,
light_block=True,
use_lab=False,
lr_mult=1.0,
):
super().__init__()
self.identity = identity
self.layers = nn.ModuleList()
        block_class = LightConvBNAct if light_block else ConvBNAct
        for i in range(layer_num):
            self.layers.append(
                block_class(
in_channels=in_channels if i == 0 else mid_channels,
out_channels=mid_channels,
stride=1,
kernel_size=kernel_size,
use_lab=use_lab,
lr_mult=lr_mult,
)
)
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_squeeze_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
self.aggregation_excitation_conv = ConvBNAct(
in_channels=out_channels // 2,
out_channels=out_channels,
kernel_size=1,
stride=1,
use_lab=use_lab,
lr_mult=lr_mult,
)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
x = self.aggregation_squeeze_conv(x)
x = self.aggregation_excitation_conv(x)
if self.identity:
x += identity
return x
class HGV2_Stage(TheseusLayer):
"""
HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
Args:
in_channels (int): Number of input channels.
mid_channels (int): Number of middle channels.
out_channels (int): Number of output channels.
block_num (int): Number of blocks in the HGV2 stage.
layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
        is_downsample (bool): Whether to use the downsampling operation. Defaults to True.
light_block (bool): Whether to use light block. Defaults to True.
kernel_size (int): Size of the convolution kernel. Defaults to 3.
        use_lab (bool, optional): Whether to use the LAB operation. Defaults to False.
        stride (int/list, optional): Stride of the downsampling convolution. Defaults to 2.
        lr_mult (float, optional): Learning rate multiplier for the layer. Defaults to 1.0.
"""
def __init__(
self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num=6,
is_downsample=True,
light_block=True,
kernel_size=3,
use_lab=False,
stride=2,
lr_mult=1.0,
):
super().__init__()
self.is_downsample = is_downsample
if self.is_downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
groups=in_channels,
use_act=False,
use_lab=use_lab,
lr_mult=lr_mult,
)
blocks_list = []
for i in range(block_num):
blocks_list.append(
HGV2_Block(
in_channels=in_channels if i == 0 else out_channels,
mid_channels=mid_channels,
out_channels=out_channels,
kernel_size=kernel_size,
layer_num=layer_num,
identity=False if i == 0 else True,
light_block=light_block,
use_lab=use_lab,
lr_mult=lr_mult,
)
)
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.is_downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class PPHGNetV2(TheseusLayer):
"""
PPHGNetV2
Args:
stage_config (dict): Config for PPHGNetV2 stages. such as the number of channels, stride, etc.
stem_channels: (list): Number of channels of the stem of the PPHGNetV2.
use_lab (bool): Whether to use the LAB operation. Defaults to False.
use_last_conv (bool): Whether to use the last conv layer as the output channel. Defaults to True.
class_expand (int): Number of channels for the last 1x1 convolutional layer.
        dropout_prob (float): Dropout probability for the last 1x1 convolutional layer. Defaults to 0.0.
        class_num (int): The number of classes for the classification layer. Defaults to 1000.
        lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
        det (bool): Whether to build the detection variant, which returns multi-scale feature maps. Defaults to False.
        text_rec (bool): Whether to build the text-recognition variant. Defaults to False.
        out_indices (list): Indices of the stages whose outputs are collected in detection mode. Defaults to [0, 1, 2, 3].
Returns:
model: nn.Module. Specific PPHGNetV2 model depends on args.
"""
def __init__(
self,
stage_config,
stem_channels=[3, 32, 64],
use_lab=False,
use_last_conv=True,
class_expand=2048,
dropout_prob=0.0,
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
det=False,
text_rec=False,
out_indices=None,
**kwargs,
):
super().__init__()
self.det = det
self.text_rec = text_rec
self.use_lab = use_lab
self.use_last_conv = use_last_conv
self.class_expand = class_expand
self.class_num = class_num
self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
self.out_channels = []
# stem
self.stem = StemBlock(
in_channels=stem_channels[0],
mid_channels=stem_channels[1],
out_channels=stem_channels[2],
use_lab=use_lab,
lr_mult=lr_mult_list[0],
text_rec=text_rec,
)
# stages
self.stages = nn.ModuleList()
for i, k in enumerate(stage_config):
(
in_channels,
mid_channels,
out_channels,
block_num,
is_downsample,
light_block,
kernel_size,
layer_num,
stride,
) = stage_config[k]
self.stages.append(
HGV2_Stage(
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
is_downsample,
light_block,
kernel_size,
use_lab,
stride,
lr_mult=lr_mult_list[i + 1],
)
)
if i in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if self.use_last_conv:
self.last_conv = nn.Conv2d(
in_channels=out_channels,
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.act = nn.ReLU()
if self.use_lab:
self.lab = LearnableAffineBlock()
# self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
self.dropout = nn.Dropout(p=dropout_prob)
self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
if not self.det:
self.fc = nn.Linear(
self.class_expand if self.use_last_conv else out_channels,
self.class_num,
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.text_rec:
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
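# Note on the forward modes implemented above:
#   det=True      -> returns the list of stage outputs selected by out_indices
#   text_rec=True -> returns stage-4 features pooled to height 1 (train) or
#                    avg-pooled with a [3, 2] window (eval)
#   otherwise     -> returns the raw stage-4 feature map; the avg_pool /
#                    last_conv / fc layers built in __init__ are not applied
#                    by this forward (presumably kept for checkpoint
#                    compatibility with the Paddle classifier).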
def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B0
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B0` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [16, 16, 64, 1, False, False, 3, 3, 2],
        "stage2": [64, 32, 256, 1, True, False, 3, 3, 2],
        "stage3": [256, 64, 512, 2, True, True, 5, 3, 2],
        "stage4": [512, 128, 1024, 1, True, True, 5, 3, 2],
}
model = PPHGNetV2(
stem_channels=[3, 16, 16], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B1
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B1` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [32, 32, 64, 1, False, False, 3, 3, 2],
        "stage2": [64, 48, 256, 1, True, False, 3, 3, 2],
        "stage3": [256, 96, 512, 2, True, True, 5, 3, 2],
        "stage4": [512, 192, 1024, 1, True, True, 5, 3, 2],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B2
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B2` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [32, 32, 96, 1, False, False, 3, 4, 2],
        "stage2": [96, 64, 384, 1, True, False, 3, 4, 2],
        "stage3": [384, 128, 768, 3, True, True, 5, 4, 2],
        "stage4": [768, 256, 1536, 1, True, True, 5, 4, 2],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B3
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B3` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [32, 32, 128, 1, False, False, 3, 5, 2],
        "stage2": [128, 64, 512, 1, True, False, 3, 5, 2],
        "stage3": [512, 128, 1024, 3, True, True, 5, 5, 2],
        "stage4": [1024, 256, 2048, 1, True, True, 5, 5, 2],
}
model = PPHGNetV2(
stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
)
return model
def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
"""
PPHGNetV2_B4
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B4` model determined by the args.
"""
stage_config_rec = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
"stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
}
stage_config_det = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
}
model = PPHGNetV2(
stem_channels=[3, 32, 48],
stage_config=stage_config_det if det else stage_config_rec,
use_lab=False,
det=det,
text_rec=text_rec,
**kwargs,
)
return model
def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B5
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B5` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [64, 64, 128, 1, False, False, 3, 6, 2],
        "stage2": [128, 128, 512, 2, True, False, 3, 6, 2],
        "stage3": [512, 256, 1024, 5, True, True, 5, 6, 2],
        "stage4": [1024, 512, 2048, 2, True, True, 5, 6, 2],
}
model = PPHGNetV2(
stem_channels=[3, 32, 64], stage_config=stage_config, use_lab=False, **kwargs
)
return model
def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNetV2_B6
Args:
        pretrained (bool/str): If `True`, load pretrained parameters; if a str, it is
            treated as the path of the pretrained model. Defaults to `False`.
        use_ssld (bool): Whether to use the SSLD pretrained model when pretrained is True.
    Returns:
        model (nn.Module): the `PPHGNetV2_B6` model determined by the args.
"""
stage_config = {
        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
        "stage1": [96, 96, 192, 2, False, False, 3, 6, 2],
        "stage2": [192, 192, 512, 3, True, False, 3, 6, 2],
        "stage3": [512, 384, 1024, 6, True, True, 5, 6, 2],
        "stage4": [1024, 768, 2048, 3, True, True, 5, 6, 2],
}
model = PPHGNetV2(
stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
)
return model
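# Minimal usage sketch for the factories above (illustrative; `pretrained` and
# `use_ssld` are accepted for Paddle API compatibility but not consumed here,
# so weight loading is left to the caller):
#   import torch
#   backbone = PPHGNetV2_B4(det=True)
#   feats = backbone(torch.randn(1, 3, 640, 640))
#   [f.shape[1] for f in feats]  # stage channels: [128, 512, 1024, 2048]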
class PPHGNetV2_B4_Formula(nn.Module):
"""
PPHGNetV2_B4_Formula
Args:
in_channels (int): Number of input channels. Default is 3 (for RGB images).
class_num (int): Number of classes for classification. Default is 1000.
Returns:
model: nn.Module. Specific `PPHGNetV2_B4` model with defined architecture.
"""
def __init__(self, in_channels=3, class_num=1000):
super().__init__()
self.in_channels = in_channels
self.out_channels = 2048
stage_config = {
            # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
}
self.pphgnet_b4 = PPHGNetV2(
stem_channels=[3, 32, 48],
stage_config=stage_config,
class_num=class_num,
use_lab=False,
)
def forward(self, input_data):
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
pphgnet_b4_output = self.pphgnet_b4(pixel_values)
b, c, h, w = pphgnet_b4_output.shape
        pphgnet_b4_output = pphgnet_b4_output.reshape(b, c, h * w).permute(0, 2, 1)
pphgnet_b4_output = DonutSwinModelOutput(
last_hidden_state=pphgnet_b4_output,
pooler_output=None,
hidden_states=None,
attentions=False,
reshaped_hidden_states=None,
)
if self.training:
return pphgnet_b4_output, label, attention_mask
else:
return pphgnet_b4_output
class PPHGNetV2_B6_Formula(nn.Module):
"""
PPHGNetV2_B6_Formula
Args:
in_channels (int): Number of input channels. Default is 3 (for RGB images).
class_num (int): Number of classes for classification. Default is 1000.
Returns:
model: nn.Module. Specific `PPHGNetV2_B6` model with defined architecture.
"""
def __init__(self, in_channels=3, class_num=1000):
super().__init__()
self.in_channels = in_channels
self.out_channels = 2048
stage_config = {
            # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [96, 96, 192, 2, False, False, 3, 6, 2],
"stage2": [192, 192, 512, 3, True, False, 3, 6, 2],
"stage3": [512, 384, 1024, 6, True, True, 5, 6, 2],
"stage4": [1024, 768, 2048, 3, True, True, 5, 6, 2],
}
self.pphgnet_b6 = PPHGNetV2(
stem_channels=[3, 48, 96],
class_num=class_num,
stage_config=stage_config,
use_lab=False,
)
def forward(self, input_data):
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
pphgnet_b6_output = self.pphgnet_b6(pixel_values)
b, c, h, w = pphgnet_b6_output.shape
        pphgnet_b6_output = pphgnet_b6_output.reshape(b, c, h * w).permute(0, 2, 1)
pphgnet_b6_output = DonutSwinModelOutput(
last_hidden_state=pphgnet_b6_output,
pooler_output=None,
hidden_states=None,
attentions=False,
reshaped_hidden_states=None,
)
if self.training:
return pphgnet_b6_output, label, attention_mask
else:
return pphgnet_b6_output
import numpy as np
import torch
from torch import nn
from ..common import Activation
def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0.0 or not training:
return x
    keep_prob = torch.as_tensor(1 - drop_prob, dtype=x.dtype, device=x.device)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over all but the batch dim
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor = torch.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
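# Sanity sketch for drop_path (training=True): with drop_prob=0.2 each sample
# survives with probability 0.8 and survivors are rescaled by 1/0.8, so the
# expected output equals the input; e.g. for x = torch.ones(4, 8), surviving
# rows become 1.25 and dropped rows become 0.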
class ConvBNLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act="gelu",
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias_attr,
)
self.norm = nn.BatchNorm2d(out_channels)
self.act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer="gelu",
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = Activation(act_type=act_layer, inplace=True)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
num_heads=8,
HW=[8, 25],
local_k=[3, 3],
):
super().__init__()
self.HW = HW
self.dim = dim
self.local_mixer = nn.Conv2d(
dim,
dim,
local_k,
1,
[local_k[0] // 2, local_k[1] // 2],
groups=num_heads,
)
def forward(self, x):
h = self.HW[0]
w = self.HW[1]
        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
x = self.local_mixer(x)
x = x.flatten(2).permute(0, 2, 1)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
mixer="Global",
HW=[8, 25],
local_k=[7, 11],
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.HW = HW
if HW is not None:
H = HW[0]
W = HW[1]
self.N = H * W
self.C = dim
if mixer == "Local" and HW is not None:
hk = local_k[0]
wk = local_k[1]
mask = torch.ones(H * W, H + hk - 1, W + wk - 1, dtype=torch.float32)
for h in range(0, H):
for w in range(0, W):
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
1
)
mask_inf = torch.full(
[H * W, H * W], fill_value=float("-Inf"), dtype=torch.float32
)
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
            # register as a buffer so the mask follows the module across devices
            self.register_buffer("mask", mask.unsqueeze(0).unsqueeze(1))
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute(
2, 0, 3, 1, 4
)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q.matmul(k.permute(0, 1, 3, 2))
if self.mixer == "Local":
attn += self.mask
attn = nn.functional.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).permute(0, 2, 1, 3).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
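# Local-mixer sketch: for mixer="Local" each query at grid position (h, w) may
# only attend within an hk x wk window around itself; every entry outside the
# window receives -inf before the softmax. For a tiny HW=[2, 2] with
# local_k=[3, 3], every position already sees all 4 tokens, so the mask is all
# zeros and "Local" degenerates to "Global".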
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mixer="Global",
local_mixer=[7, 11],
HW=None,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer="gelu",
norm_layer="nn.LayerNorm",
epsilon=1e-6,
prenorm=True,
):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm1 = norm_layer(dim)
if mixer == "Global" or mixer == "Local":
self.mixer = Attention(
dim,
num_heads=num_heads,
mixer=mixer,
HW=HW,
local_k=local_mixer,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
elif mixer == "Conv":
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.prenorm = prenorm
def forward(self, x):
if self.prenorm:
x = self.norm1(x + self.drop_path(self.mixer(x)))
x = self.norm2(x + self.drop_path(self.mlp(x)))
else:
x = x + self.drop_path(self.mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(
self,
img_size=[32, 100],
in_channels=3,
embed_dim=768,
sub_num=2,
patch_size=[4, 4],
mode="pope",
):
super().__init__()
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
if mode == "pope":
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
)
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 4,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act="gelu",
bias_attr=True,
),
)
elif mode == "linear":
self.proj = nn.Conv2d(
1, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.num_patches = (
img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
)
def forward(self, x):
B, C, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), "Input image size ({}*{}) doesn't match model ({}*{}).".format(
H, W, self.img_size[0], self.img_size[1]
)
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
class SubSample(nn.Module):
def __init__(
self,
in_channels,
out_channels,
types="Pool",
stride=[2, 1],
sub_norm="nn.LayerNorm",
act=None,
):
super().__init__()
self.types = types
if types == "Pool":
self.avgpool = nn.AvgPool2d(
kernel_size=[3, 5], stride=stride, padding=[1, 2]
)
self.maxpool = nn.MaxPool2d(
kernel_size=[3, 5], stride=stride, padding=[1, 2]
)
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
else:
self.act = None
def forward(self, x):
if self.types == "Pool":
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
out = self.proj(x.flatten(2).permute(0, 2, 1))
else:
x = self.conv(x)
out = x.flatten(2).permute(0, 2, 1)
out = self.norm(out)
if self.act is not None:
out = self.act(out)
return out
class SVTRNet(nn.Module):
def __init__(
self,
img_size=[32, 100],
in_channels=3,
embed_dim=[64, 128, 256],
depth=[3, 6, 3],
num_heads=[2, 4, 8],
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
local_mixer=[[7, 11], [7, 11], [7, 11]],
patch_merging="Conv", # Conv, Pool, None
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
last_drop=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer="nn.LayerNorm",
sub_norm="nn.LayerNorm",
epsilon=1e-6,
out_channels=192,
out_char_num=25,
block_unit="Block",
act="gelu",
last_stage=True,
sub_num=2,
prenorm=True,
use_lenhead=False,
**kwargs
):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
patch_merging = (
None
if patch_merging != "Conv" and patch_merging != "Pool"
else patch_merging
)
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim[0],
sub_num=sub_num,
)
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[0],
num_heads=num_heads[0],
mixer=mixer[0 : depth[0]][i],
HW=self.HW,
local_mixer=local_mixer[0],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[0 : depth[0]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[0])
]
)
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0],
embed_dim[1],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging,
)
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
self.blocks2 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[1],
num_heads=num_heads[1],
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
HW=HW,
local_mixer=local_mixer[1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[1])
]
)
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1],
embed_dim[2],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging,
)
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
self.blocks3 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[2],
num_heads=num_heads[2],
mixer=mixer[depth[0] + depth[1] :][i],
HW=HW,
local_mixer=local_mixer[2],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=act,
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1] :][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[2])
]
)
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
self.last_conv = nn.Conv2d(
in_channels=embed_dim[2],
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.hardswish = Activation("hard_swish", inplace=True) # nn.Hardswish()
# self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = Activation(
"hard_swish", inplace=True
) # nn.Hardswish()
self.dropout_len = nn.Dropout(p=last_drop)
torch.nn.init.xavier_normal_(self.pos_embed)
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample1(
x.permute(0, 2, 1).reshape(
[-1, self.embed_dim[0], self.HW[0], self.HW[1]]
)
)
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(
x.permute(0, 2, 1).reshape(
[-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
)
)
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.use_lenhead:
len_x = self.len_conv(x.mean(1))
len_x = self.dropout_len(self.hardswish_len(len_x))
if self.last_stage:
if self.patch_merging is not None:
h = self.HW[0] // 4
else:
h = self.HW[0]
x = self.avg_pool(
x.permute(0, 2, 1).reshape([-1, self.embed_dim[2], h, self.HW[1]])
)
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.use_lenhead:
return x, len_x
return x
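# Shape walk-through with the default config (illustrative):
#   net = SVTRNet()
#   y = net(torch.randn(1, 3, 32, 100))
#   # patch_embed: (1, 3, 32, 100) -> (1, 200, 64) with HW = [8, 25];
#   # each SubSample halves the token height (8 -> 4 -> 2);
#   # last_stage pools to out_char_num columns, so y is (1, 192, 1, 25).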
import torch
import torch.nn.functional as F
from torch import nn
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
# out = max(0, min(1, slop*x+offset))
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
class GELU(nn.Module):
def __init__(self, inplace=True):
super(GELU, self).__init__()
self.inplace = inplace
def forward(self, x):
return torch.nn.functional.gelu(x)
class Swish(nn.Module):
def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
if self.inplace:
x.mul_(torch.sigmoid(x))
return x
else:
return x * torch.sigmoid(x)
class Activation(nn.Module):
def __init__(self, act_type, inplace=True):
super(Activation, self).__init__()
act_type = act_type.lower()
if act_type == "relu":
self.act = nn.ReLU(inplace=inplace)
elif act_type == "relu6":
self.act = nn.ReLU6(inplace=inplace)
elif act_type == "sigmoid":
raise NotImplementedError
elif act_type == "hard_sigmoid":
self.act = Hsigmoid(
inplace
) # nn.Hardsigmoid(inplace=inplace)#Hsigmoid(inplace)#
elif act_type == "hard_swish" or act_type == "hswish":
self.act = Hswish(inplace=inplace)
elif act_type == "leakyrelu":
self.act = nn.LeakyReLU(inplace=inplace)
elif act_type == "gelu":
self.act = GELU(inplace=inplace)
elif act_type == "swish":
self.act = Swish(inplace=inplace)
else:
raise NotImplementedError
def forward(self, inputs):
return self.act(inputs)
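# Usage sketch: Activation is a small string-keyed registry, e.g.
#   act = Activation("hard_swish")
#   y = act(torch.randn(2, 3))  # x * relu6(x + 3) / 6
# Note the "hard_sigmoid" branch keeps the Paddle formula
# relu6(1.2 * x + 3) / 6, which differs slightly from torch.nn.Hardsigmoid.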
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_head"]
def build_head(config, **kwargs):
# det head
from .det_db_head import DBHead, PFHeadLocal
# rec head
from .rec_ctc_head import CTCHead
from .rec_multi_head import MultiHead
# cls head
from .cls_head import ClsHead
support_dict = [
"DBHead",
"CTCHead",
"ClsHead",
"MultiHead",
"PFHeadLocal",
]
module_name = config.pop("name")
char_num = config.pop("char_num", 6625)
    assert module_name in support_dict, "head only support {}".format(support_dict)
module_class = eval(module_name)(**config, **kwargs)
return module_class
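# Example (hypothetical values): `name` selects the class and the remaining
# keys are forwarded as keyword arguments, e.g.
#   head = build_head({"name": "DBHead", "in_channels": 256, "k": 50})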
import torch
import torch.nn.functional as F
from torch import nn
class ClsHead(nn.Module):
"""
Class orientation
Args:
params(dict): super parameters for build Class network
"""
def __init__(self, in_channels, class_dim, **kwargs):
super(ClsHead, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(in_channels, class_dim, bias=True)
def forward(self, x):
x = self.pool(x)
x = torch.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
x = F.softmax(x, dim=1)
return x
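# Shape sketch (hypothetical sizes): ClsHead pools (N, C, H, W) to (N, C) and
# classifies, e.g. ClsHead(in_channels=200, class_dim=4) maps an input of
# shape (8, 200, 1, 25) to softmax scores of shape (8, 4).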
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..common import Activation
from ..backbones.det_mobilenet_v3 import ConvBNLayer
class Head(nn.Module):
def __init__(self, in_channels, **kwargs):
super(Head, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels // 4,
kernel_size=3,
padding=1,
bias=False)
        self.conv_bn1 = nn.BatchNorm2d(in_channels // 4)
self.relu1 = Activation(act_type='relu')
self.conv2 = nn.ConvTranspose2d(
in_channels=in_channels // 4,
out_channels=in_channels // 4,
kernel_size=2,
stride=2)
        self.conv_bn2 = nn.BatchNorm2d(in_channels // 4)
self.relu2 = Activation(act_type='relu')
self.conv3 = nn.ConvTranspose2d(
in_channels=in_channels // 4,
out_channels=1,
kernel_size=2,
stride=2)
def forward(self, x, return_f=False):
x = self.conv1(x)
x = self.conv_bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.conv_bn2(x)
x = self.relu2(x)
if return_f is True:
f = x
x = self.conv3(x)
x = torch.sigmoid(x)
if return_f is True:
return x, f
return x
class DBHead(nn.Module):
"""
Differentiable Binarization (DB) for text detection:
see https://arxiv.org/abs/1911.08947
args:
params(dict): super parameters for build DB network
"""
def __init__(self, in_channels, k=50, **kwargs):
super(DBHead, self).__init__()
self.k = k
binarize_name_list = [
'conv2d_56', 'batch_norm_47', 'conv2d_transpose_0', 'batch_norm_48',
'conv2d_transpose_1', 'binarize'
]
thresh_name_list = [
'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
'conv2d_transpose_3', 'thresh'
]
        self.binarize = Head(in_channels, **kwargs)  # binarize_name_list
        self.thresh = Head(in_channels, **kwargs)  # thresh_name_list
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
def forward(self, x):
shrink_maps = self.binarize(x)
return {'maps': shrink_maps}
class LocalModule(nn.Module):
def __init__(self, in_c, mid_c, use_distance=True):
        super(LocalModule, self).__init__()
self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
self.last_1 = nn.Conv2d(mid_c, 1, 1, 1, 0)
def forward(self, x, init_map, distance_map):
outf = torch.cat([init_map, x], dim=1)
# last Conv
out = self.last_1(self.last_3(outf))
return out
class PFHeadLocal(DBHead):
def __init__(self, in_channels, k=50, mode='small', **kwargs):
super(PFHeadLocal, self).__init__(in_channels, k, **kwargs)
self.mode = mode
self.up_conv = nn.Upsample(scale_factor=2, mode="nearest")
if self.mode == 'large':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
elif self.mode == 'small':
self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
def forward(self, x, targets=None):
shrink_maps, f = self.binarize(x, return_f=True)
base_maps = shrink_maps
cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
        cbn_maps = torch.sigmoid(cbn_maps)
return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
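# Usage sketch (illustrative): both heads take the FPN fusion map and return
# probability maps upsampled 4x by the two stride-2 deconvs in Head, e.g.
#   head = PFHeadLocal(in_channels=96)
#   out = head(torch.randn(1, 96, 160, 160))
#   out["maps"].shape  # (1, 1, 640, 640); out["cbn_maps"] has the same shape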
import torch.nn.functional as F
from torch import nn
class CTCHead(nn.Module):
def __init__(
self,
in_channels,
out_channels=6625,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs
):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,
)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = (x, predicts)
else:
result = predicts
if not self.training:
predicts = F.softmax(predicts, dim=2)
result = predicts
return result
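# Shape sketch: the input is a sequence feature (N, T, in_channels) and the
# output is (N, T, out_channels) logits while training. Note the eval branch
# overwrites `result` with the softmax predictions alone, so return_feats has
# no effect at inference time.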
from torch import nn
from ..necks.rnn import Im2Seq, SequenceEncoder
from .rec_ctc_head import CTCHead
class FCTranspose(nn.Module):
def __init__(self, in_channels, out_channels, only_transpose=False):
super().__init__()
self.only_transpose = only_transpose
if not self.only_transpose:
self.fc = nn.Linear(in_channels, out_channels, bias=False)
def forward(self, x):
if self.only_transpose:
return x.permute([0, 2, 1])
else:
return self.fc(x.permute([0, 2, 1]))
class MultiHead(nn.Module):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop("head_list")
self.gtc_head = "sar"
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
if name == "SARHead":
pass
elif name == "NRTRHead":
pass
elif name == "CTCHead":
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]["Neck"]
encoder_type = neck_args.pop("name")
self.ctc_encoder = SequenceEncoder(
in_channels=in_channels, encoder_type=encoder_type, **neck_args
)
# ctc head
head_args = self.head_list[idx][name].get("Head", {})
if head_args is None:
head_args = {}
self.ctc_head = CTCHead(
in_channels=self.ctc_encoder.out_channels,
out_channels=out_channels_list["CTCLabelDecode"],
**head_args,
)
else:
raise NotImplementedError(f"{name} is not supported in MultiHead yet")
def forward(self, x, data=None):
ctc_encoder = self.ctc_encoder(x)
return self.ctc_head(ctc_encoder)
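# Example head_list config (hypothetical values shaped like the YAML this port
# expects; only the CTC branch is actually executed in forward):
#   MultiHead(
#       in_channels=480,
#       out_channels_list={"CTCLabelDecode": 6625},
#       head_list=[
#           {"CTCHead": {"Neck": {"name": "svtr"}, "Head": {}}},
#           {"SARHead": {}},
#       ],
#   )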
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["build_neck"]
def build_neck(config):
from .db_fpn import DBFPN, LKPAN, RSEFPN
from .rnn import SequenceEncoder
support_dict = ["DBFPN", "SequenceEncoder", "RSEFPN", "LKPAN"]
module_name = config.pop("name")
    assert module_name in support_dict, "neck only support {}".format(support_dict)
module_class = eval(module_name)(**config)
return module_class
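# Example (hypothetical values): as with build_head, `name` picks the class
# and the rest are keyword arguments, e.g.
#   neck = build_neck(
#       {"name": "DBFPN", "in_channels": [256, 512, 1024, 2048], "out_channels": 256}
#   )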