Unverified Commit 0c7a0882 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
# Copyright (c) Opendatalab. All rights reserved.
import copy
import os.path
import os
import warnings
from pathlib import Path
@@ -9,8 +9,10 @@ import numpy as np
import yaml
from loguru import logger
from magic_pdf.libs.config_reader import get_device, get_local_models_dir
from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility
import argparse
@@ -55,7 +57,7 @@ class PytorchPaddleOCR(TextSystem):
self.lang = kwargs.get('lang', 'ch')
device = get_device()
if device == 'cpu' and self.lang in ['ch', 'ch_server']:
if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
self.lang = 'ch_lite'
@@ -74,9 +76,14 @@ class PytorchPaddleOCR(TextSystem):
with open(models_config_path) as file:
config = yaml.safe_load(file)
det, rec, dict_file = get_model_params(self.lang, config)
ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
ocr_models_dir = ModelPath.pytorch_paddle
det_model_path = f"{ocr_models_dir}/{det}"
det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
rec_model_path = f"{ocr_models_dir}/{rec}"
rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
kwargs['det_model_path'] = det_model_path
kwargs['rec_model_path'] = rec_model_path
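        # det/rec weights are now resolved under the auto-downloaded model root for
        # ModelPath.pytorch_paddle, rather than the local models dir from the config file.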
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
# kwargs['rec_batch_num'] = 8
......
@@ -20,6 +20,7 @@ def build_backbone(config, model_type):
from .det_mobilenet_v3 import MobileNetV3
from .rec_hgnet import PPHGNet_small
from .rec_lcnetv3 import PPLCNetV3
from .rec_pphgnetv2 import PPHGNetV2_B4
support_dict = [
"MobileNetV3",
@@ -28,6 +29,7 @@ def build_backbone(config, model_type):
"ResNet_SAST",
"PPLCNetV3",
"PPHGNet_small",
'PPHGNetV2_B4',
]
elif model_type == "rec" or model_type == "cls":
from .rec_hgnet import PPHGNet_small
......
import collections.abc
from collections import OrderedDict
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class DonutSwinConfig(object):
model_type = "donut-swin"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=224,
patch_size=4,
num_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
initializer_range=0.02,
layer_norm_eps=1e-5,
**kwargs,
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
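        # Channels double after every patch-merging stage, so the final hidden size is
        # embed_dim * 2**(num_stages - 1), e.g. 96 * 2**3 = 768 for depths=[2, 2, 6, 2].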
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
print(f"Can't set {key} with value {value} for {self}")
raise err
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(OrderedDict):
last_hidden_state = None
hidden_states = None
attentions = None
reshaped_hidden_states = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().__setattr__(key, value)
def to_tuple(self):
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
class DonutSwinModelOutput(OrderedDict):
last_hidden_state = None
pooler_output = None
hidden_states = None
attentions = None
reshaped_hidden_states = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().__setattr__(key, value)
def to_tuple(self):
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.reshape(
[
batch_size,
height // window_size,
window_size,
width // window_size,
window_size,
num_channels,
]
)
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).reshape(
[-1, window_size, window_size, num_channels]
)
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.reshape(
[
-1,
height // window_size,
width // window_size,
window_size,
window_size,
num_channels,
]
)
    windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(
[-1, height, width, num_channels]
)
return windows
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
"""
def __init__(self, config, use_mask_token=False):
super().__init__()
self.patch_embeddings = DonutSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
if use_mask_token:
# self.mask_token = paddle.create_parameter(
# [1, 1, config.embed_dim], dtype="float32"
# )
self.mask_token = nn.Parameter(
nn.init.xavier_uniform_(torch.zeros(1, 1, config.embed_dim).to(torch.float32))
)
nn.init.zeros_(self.mask_token)
else:
self.mask_token = None
if config.use_absolute_embeddings:
# self.position_embeddings = paddle.create_parameter(
# [1, num_patches + 1, config.embed_dim], dtype="float32"
# )
self.position_embeddings = nn.Parameter(
nn.init.xavier_uniform_(torch.zeros(1, num_patches + 1, config.embed_dim).to(torch.float32))
)
            nn.init.zeros_(self.position_embeddings)
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.shape
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class MyConv2d(nn.Conv2d):
def __init__(
self,
in_channel,
out_channels,
kernel_size,
stride=1,
padding="SAME",
dilation=1,
groups=1,
bias_attr=False,
eps=1e-6,
):
super().__init__(
in_channel,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
            bias=bias_attr,
)
# self.weight = paddle.create_parameter(
# [out_channels, in_channel, kernel_size[0], kernel_size[1]], dtype="float32"
# )
        self.weight = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(out_channels, in_channel, kernel_size[0], kernel_size[1]).to(torch.float32)
            )
        )
        # self.bias = paddle.create_parameter([out_channels], dtype="float32")
        # xavier init is undefined for 1-D tensors, so the bias starts as zeros
        self.bias = nn.Parameter(torch.zeros(out_channels, dtype=torch.float32))
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
x = F.conv2d(
x,
self.weight,
self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
)
return x
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
class DonutSwinPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = (
image_size
if isinstance(image_size, collections.abc.Iterable)
else (image_size, image_size)
)
patch_size = (
patch_size
if isinstance(patch_size, collections.abc.Iterable)
else (patch_size, patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (
image_size[0] // patch_size[0]
)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.is_export = config.is_export
self.grid_size = (
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
)
        self.projection = nn.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(
self,
input_resolution: Tuple[int],
dim: int,
norm_layer: nn.Module = nn.LayerNorm,
is_export=False,
):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
self.is_export = is_export
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
if self.is_export:
pad_values = torch.tensor(pad_values, dtype=torch.int32)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(
self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
) -> torch.Tensor:
height, width = input_dimensions
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.reshape([batch_size, height, width, num_channels])
input_feature = self.maybe_pad(input_feature, height, width)
input_feature_0 = input_feature[:, 0::2, 0::2, :]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
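        # Gather each 2x2 patch neighbourhood into the channel dim:
        # (B, H, W, C) -> (B, H/2 * W/2, 4C), then LayerNorm + Linear reduce 4C -> 2C.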
input_feature = torch.cat(
[input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
)
input_feature = input_feature.reshape(
[batch_size, -1, 4 * num_channels]
) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (
input.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
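    # Per-sample Bernoulli mask: a sample survives with probability keep_prob and is
    # rescaled by 1 / keep_prob, so the expected output matches the input (stochastic depth).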
random_tensor = keep_prob + torch.rand(
shape,
dtype=input.dtype,
)
random_tensor.floor_() # binarize
output = input / keep_prob * random_tensor
return output
# Copied from transformers.models.swin.modeling_swin.SwinDropPath
class DonutSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class DonutSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size
if isinstance(window_size, collections.abc.Iterable)
else (window_size, window_size)
)
# self.relative_position_bias_table = paddle.create_parameter(
# [(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads],
# dtype="float32",
# )
        self.relative_position_bias_table = nn.Parameter(
nn.init.xavier_normal_(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads).to(torch.float32)
)
)
nn.init.zeros_(self.relative_position_bias_table)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0)
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
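        # relative_position_index maps every (i, j) token pair inside a window to one of
        # the (2*Wh - 1) * (2*Ww - 1) learnable bias entries in the table above.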
self.register_buffer("relative_position_index", relative_position_index)
        self.query = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
        new_x_shape = tuple(x.shape[:-1]) + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.reshape(new_x_shape)
        return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask=None,
head_mask=None,
output_attentions=False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.reshape([-1])
]
relative_position_bias = relative_position_bias.reshape(
[
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
]
)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.reshape(
[
batch_size // mask_shape,
mask_shape,
self.num_attention_heads,
dim,
dim,
]
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(
0
)
attention_scores = attention_scores.reshape(
[-1, self.num_attention_heads, dim, dim]
)
# Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3)
new_context_layer_shape = tuple(context_layer.shape[:-2]) + (
self.all_head_size,
)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
self.output = DonutSwinSelfOutput(config, dim)
self.pruned_heads = set()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask=None,
head_mask=None,
output_attentions=False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
self.intermediate_act_fn = F.gelu
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(
config, dim, num_heads, window_size=self.window_size
)
self.drop_path = (
DonutSwinDropPath(config.drop_path_rate)
if config.drop_path_rate > 0.0
else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DonutSwinIntermediate(config, dim)
self.output = DonutSwinOutput(config, dim)
self.is_export = config.is_export
def set_shift_and_window_size(self, input_resolution):
if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask_export(self, height, width, dtype):
attn_mask = None
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
if self.shift_size > 0:
img_mask[:, height_slice, width_slice, :] = count
count += 1
        if torch.tensor(self.shift_size > 0, dtype=torch.bool):
# calculate attention mask for SW-MSA
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size]
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
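            # img_mask labels each pixel with the shifted region it belongs to; windows
            # that mix regions receive a -100 bias below so cross-region attention is suppressed.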
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size]
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)  # F.pad pads (C, W, H) from the last dim backwards for (B, H, W, C) input
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
always_partition=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not always_partition:
self.set_shift_and_window_size(input_dimensions)
else:
pass
height, width = input_dimensions
batch_size, _, channels = hidden_states.shape
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.reshape([batch_size, height, width, channels])
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shift_value = (-self.shift_size, -self.shift_size)
if self.is_export:
shift_value = torch.tensor(shift_value, dtype=torch.int32)
shifted_hidden_states = torch.roll(
hidden_states, shifts=shift_value, dims=(1, 2)
)
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(
shifted_hidden_states, self.window_size
)
hidden_states_windows = hidden_states_windows.reshape(
[-1, self.window_size * self.window_size, channels]
)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
attention_outputs = self.attention(
hidden_states_windows,
attn_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
attention_windows = attention_output.reshape(
[-1, self.window_size, self.window_size, channels]
)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
)
# reverse cyclic shift
if self.shift_size > 0:
shift_value = (self.shift_size, self.shift_size)
if self.is_export:
shift_value = torch.tensor(shift_value, dtype=torch.int32)
attention_windows = torch.roll(
shifted_windows, shifts=shift_value, dims=(1, 2)
)
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.reshape(
[batch_size, height * width, channels]
)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_outputs = (
(layer_output, attention_outputs[1])
if output_attentions
else (layer_output,)
)
return layer_outputs
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
def __init__(
self, config, dim, input_resolution, depth, num_heads, drop_path, downsample
):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
DonutSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
self.is_export = config.is_export
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=nn.LayerNorm,
is_export=self.is_export,
)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
always_partition=False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(
hidden_states_before_downsampling, input_dimensions
)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (
hidden_states,
hidden_states_before_downsampling,
output_dimensions,
)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [
x.item()
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
]
self.layers = nn.ModuleList(
[
DonutSwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(
grid_size[0] // (2**i_layer),
grid_size[1] // (2**i_layer),
),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[
sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
],
downsample=(
DonutSwinPatchMerging
if (i_layer < self.num_layers - 1)
else None
),
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask=None,
output_attentions=False,
output_hidden_states=False,
output_hidden_states_before_downsampling=False,
always_partition=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if output_hidden_states:
batch_size, _, hidden_size = hidden_states.shape
reshaped_hidden_state = hidden_states.view(
batch_size, *input_dimensions, hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
reshaped_hidden_state = hidden_states_before_downsampling.reshape(
[
batch_size,
*(output_dimensions[0], output_dimensions[1]),
hidden_size,
]
)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
reshaped_hidden_state = hidden_states.reshape(
[batch_size, *input_dimensions, hidden_size]
)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
if output_attentions:
all_self_attentions += layer_outputs[3:]
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions]
if v is not None
)
return DonutSwinEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
class DonutSwinPreTrainedModel(nn.Module):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DonutSwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
# normal_ = Normal(mean=0.0, std=self.config.initializer_range)
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def _initialize_weights(self, module):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
def post_init(self):
self.apply(self._initialize_weights)
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
return head_mask
class DonutSwinModel(DonutSwinPreTrainedModel):
def __init__(
self,
in_channels=3,
hidden_size=1024,
num_layers=4,
num_heads=[4, 8, 16, 32],
add_pooling_layer=True,
use_mask_token=False,
is_export=False,
):
super().__init__()
donut_swin_config = {
"return_dict": True,
"output_hidden_states": False,
"output_attentions": False,
"use_bfloat16": False,
"tf_legacy_loss": False,
"pruned_heads": {},
"tie_word_embeddings": True,
"chunk_size_feed_forward": 0,
"is_encoder_decoder": False,
"is_decoder": False,
"cross_attention_hidden_size": None,
"add_cross_attention": False,
"tie_encoder_decoder": False,
"max_length": 20,
"min_length": 0,
"do_sample": False,
"early_stopping": False,
"num_beams": 1,
"num_beam_groups": 1,
"diversity_penalty": 0.0,
"temperature": 1.0,
"top_k": 50,
"top_p": 1.0,
"typical_p": 1.0,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"no_repeat_ngram_size": 0,
"encoder_no_repeat_ngram_size": 0,
"bad_words_ids": None,
"num_return_sequences": 1,
"output_scores": False,
"return_dict_in_generate": False,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"remove_invalid_values": False,
"exponential_decay_length_penalty": None,
"suppress_tokens": None,
"begin_suppress_tokens": None,
"architectures": None,
"finetuning_task": None,
"id2label": {0: "LABEL_0", 1: "LABEL_1"},
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
"tokenizer_class": None,
"prefix": None,
"bos_token_id": None,
"pad_token_id": None,
"eos_token_id": None,
"sep_token_id": None,
"decoder_start_token_id": None,
"task_specific_params": None,
"problem_type": None,
"_name_or_path": "",
"_commit_hash": None,
"_attn_implementation_internal": None,
"transformers_version": None,
"hidden_size": hidden_size,
"num_layers": num_layers,
"path_norm": True,
"use_2d_embeddings": False,
"image_size": [420, 420],
"patch_size": 4,
"num_channels": in_channels,
"embed_dim": 128,
"depths": [2, 2, 14, 2],
"num_heads": num_heads,
"window_size": 5,
"mlp_ratio": 4.0,
"qkv_bias": True,
"hidden_dropout_prob": 0.0,
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.1,
"hidden_act": "gelu",
"use_absolute_embeddings": False,
"layer_norm_eps": 1e-05,
"initializer_range": 0.02,
"is_export": is_export,
}
config = DonutSwinConfig(**donut_swin_config)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
self.out_channels = hidden_size
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
input_data=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple, DonutSwinModelOutput]:
r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
num_channels = pixel_values.shape[1]
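        # Grayscale inputs are tiled to 3 channels so they match the patch-embedding conv.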
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos
)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = None
if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
output = (sequence_output, pooled_output) + encoder_outputs[1:]
return output
donut_swin_output = DonutSwinModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
if self.training:
return donut_swin_output, label, attention_mask
else:
return donut_swin_output
\ No newline at end of file
@@ -2,37 +2,813 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .rec_donut_swin import DonutSwinModelOutput
from typing import List, Dict, Union, Callable
class AdaptiveAvgPool2D(nn.AdaptiveAvgPool2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(self.output_size, int) and self.output_size == 1:
            self._gap = True
        elif (
            isinstance(self.output_size, tuple)
            and self.output_size[0] == 1
            and self.output_size[1] == 1
        ):
            self._gap = True
        else:
            self._gap = False
    def forward(self, x):
        if self._gap:
            # Global Average Pooling
            N, C, _, _ = x.shape
            x_mean = torch.mean(x, dim=[2, 3])
            x_mean = torch.reshape(x_mean, [N, C, 1, 1])
            return x_mean
        else:
            return F.adaptive_avg_pool2d(x, output_size=self.output_size)
class IdentityBasedConv1x1(nn.Conv2d):
def __init__(self, channels, groups=1):
super(IdentityBasedConv1x1, self).__init__(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
            bias=False,
)
assert channels % groups == 0
input_dim = channels // groups
id_value = np.zeros((channels, input_dim, 1, 1))
for i in range(channels):
id_value[i, i % input_dim, 0, 0] = 1
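        # id_value is a 1x1 kernel implementing the (grouped) identity map, so
        # weight + id_tensor in forward() acts like a residual 1x1 convolution.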
self.id_tensor = torch.Tensor(id_value)
        nn.init.zeros_(self.weight)
def forward(self, input):
kernel = self.weight + self.id_tensor
result = F.conv2d(
input,
kernel,
None,
stride=1,
padding=0,
            dilation=self.dilation,
            groups=self.groups,
)
return result
def get_actual_kernel(self):
return self.weight + self.id_tensor
class BNAndPad(nn.Module):
def __init__(
self,
pad_pixels,
num_features,
epsilon=1e-5,
momentum=0.1,
last_conv_bias=None,
bn=nn.BatchNorm2d,
):
super().__init__()
        self.bn = bn(num_features, momentum=momentum, eps=epsilon)
self.pad_pixels = pad_pixels
self.last_conv_bias = last_conv_bias
def forward(self, input):
output = self.bn(input)
if self.pad_pixels > 0:
            bias = -self.bn.running_mean
            if self.last_conv_bias is not None:
                bias += self.last_conv_bias
            pad_values = self.bn.bias + self.bn.weight * (
                bias / torch.sqrt(self.bn.running_var + self.bn.eps)
            )
""" pad """
# TODO: n,h,w,c format is not supported yet
n, c, h, w = output.shape
values = pad_values.reshape([1, -1, 1, 1])
w_values = values.expand([n, -1, self.pad_pixels, w])
x = torch.cat([w_values, output, w_values], dim=2)
h = h + self.pad_pixels * 2
h_values = values.expand([n, -1, h, self.pad_pixels])
x = torch.cat([h_values, x, h_values], dim=3)
output = x
return output
@property
def weight(self):
return self.bn.weight
@property
def bias(self):
return self.bn.bias
    # expose the BN statistics under the torch attribute names used by transI_fusebn
    @property
    def running_mean(self):
        return self.bn.running_mean
    @property
    def running_var(self):
        return self.bn.running_var
    @property
    def eps(self):
        return self.bn.eps
def conv_bn(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
):
conv_layer = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
        bias=False,
        padding_mode=padding_mode,
    )
    bn_layer = nn.BatchNorm2d(num_features=out_channels)
    se = nn.Sequential()
    se.add_module("conv", conv_layer)
    se.add_module("bn", bn_layer)
    return se
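# The trans*() helpers below implement the kernel algebra used to re-parameterize a
# Diverse Branch Block: transI folds a BatchNorm into the preceding conv
# (w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps)),
# transII sums parallel branches, transIII merges a 1x1 conv followed by a kxk conv,
# transIV concatenates grouped kernels, transV expresses average pooling as a fixed
# kxk kernel, and transVI zero-pads a smaller kernel up to the target kernel size.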
def transI_fusebn(kernel, bn):
gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
return (
kernel * ((gamma / std).reshape([-1, 1, 1, 1])),
        bn.bias - bn.running_mean * gamma / std,
)
def transII_addbranch(kernels, biases):
return sum(kernels), sum(biases)
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
if groups == 1:
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
b_hat = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
else:
k_slices = []
b_slices = []
        k1_T = k1.permute(1, 0, 2, 3)
k1_group_width = k1.shape[0] // groups
k2_group_width = k2.shape[0] // groups
for g in range(groups):
k1_T_slice = k1_T[:, g * k1_group_width : (g + 1) * k1_group_width, :, :]
k2_slice = k2[g * k2_group_width : (g + 1) * k2_group_width, :, :, :]
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
b_slices.append(
(
k2_slice
* b1[g * k1_group_width : (g + 1) * k1_group_width].reshape(
[1, -1, 1, 1]
)
).sum((1, 2, 3))
)
k, b_hat = transIV_depthconcat(k_slices, b_slices)
return k, b_hat + b2
def transIV_depthconcat(kernels, biases):
return torch.cat(kernels, dim=0), torch.cat(biases)
def transV_avg(channels, kernel_size, groups):
input_dim = channels // groups
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = (
1.0 / kernel_size**2
)
return k
def transVI_multiscale(kernel, target_kernel_size):
H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
return F.pad(
kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]
)
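# DiverseBranchBlock trains with several parallel branches (kxk conv+BN, optional 1x1
# conv+BN, a 1x1-then-kxk branch, and an average-pooling branch) and can later collapse
# them, via re_parameterize(), into the single equivalent convolution dbb_reparam used
# at inference time.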
class DiverseBranchBlock(nn.Module):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
is_repped=False,
single_init=False,
**kwargs,
):
super().__init__()
padding = (filter_size - 1) // 2
dilation = 1
in_channels = num_channels
out_channels = num_filters
kernel_size = filter_size
internal_channels_1x1_3x3 = None
nonlinear = act
self.is_repped = is_repped
if nonlinear is None:
self.nonlinear = nn.Identity()
else:
            self.nonlinear = nn.ReLU()
self.kernel_size = kernel_size
self.out_channels = out_channels
self.groups = groups
assert padding == kernel_size // 2
if is_repped:
self.dbb_reparam = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
)
else:
self.dbb_origin = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.dbb_avg = nn.Sequential()
if groups < out_channels:
                self.dbb_avg.add_module(
"conv",
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias=False,
),
)
                self.dbb_avg.add_module(
                    "bn", BNAndPad(pad_pixels=padding, num_features=out_channels)
                )
                self.dbb_avg.add_module(
                    "avg",
                    nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0),
)
self.dbb_1x1 = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
groups=groups,
)
else:
                self.dbb_avg.add_module(
                    "avg",
                    nn.AvgPool2d(
                        kernel_size=kernel_size, stride=stride, padding=padding
                    ),
                )
            self.dbb_avg.add_module("avgbn", nn.BatchNorm2d(out_channels))
if internal_channels_1x1_3x3 is None:
internal_channels_1x1_3x3 = (
in_channels if groups < out_channels else 2 * in_channels
) # For mobilenet, it is better to have 2X internal channels
self.dbb_1x1_kxk = nn.Sequential()
if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module(
"idconv1", IdentityBasedConv1x1(channels=in_channels, groups=groups)
)
else:
                self.dbb_1x1_kxk.add_module(
"conv1",
nn.Conv2d(
in_channels=in_channels,
out_channels=internal_channels_1x1_3x3,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias=False,
),
)
            self.dbb_1x1_kxk.add_module(
"bn1",
BNAndPad(pad_pixels=padding, num_features=internal_channels_1x1_3x3),
)
            self.dbb_1x1_kxk.add_module(
"conv2",
nn.Conv2d(
in_channels=internal_channels_1x1_3x3,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
groups=groups,
bias=False,
),
)
self.dbb_1x1_kxk.add_sublayer("bn2", nn.BatchNorm2D(out_channels))
# The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
if single_init:
# Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
self.single_init()
def forward(self, inputs):
if self.is_repped:
return self.nonlinear(self.dbb_reparam(inputs))
out = self.dbb_origin(inputs)
if hasattr(self, "dbb_1x1"):
out += self.dbb_1x1(inputs)
out += self.dbb_avg(inputs)
out += self.dbb_1x1_kxk(inputs)
return self.nonlinear(out)
def init_gamma(self, gamma_value):
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
if hasattr(self, "dbb_1x1"):
torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
if hasattr(self, "dbb_avg"):
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
if hasattr(self, "dbb_1x1_kxk"):
torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
def single_init(self):
self.init_gamma(0.0)
if hasattr(self, "dbb_origin"):
torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
def get_equivalent_kernel_bias(self):
k_origin, b_origin = transI_fusebn(
self.dbb_origin.conv.weight, self.dbb_origin.bn
)
if hasattr(self, "dbb_1x1"):
k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
else:
k_1x1, b_1x1 = 0, 0
if hasattr(self.dbb_1x1_kxk, "idconv1"):
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
else:
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(
k_1x1_kxk_first, self.dbb_1x1_kxk.bn1
)
k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2
)
k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
k_1x1_kxk_first,
b_1x1_kxk_first,
k_1x1_kxk_second,
b_1x1_kxk_second,
groups=self.groups,
)
k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg, self.dbb_avg.avgbn)
if hasattr(self.dbb_avg, "conv"):
k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
self.dbb_avg.conv.weight, self.dbb_avg.bn
)
k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
k_1x1_avg_first,
b_1x1_avg_first,
k_1x1_avg_second,
b_1x1_avg_second,
groups=self.groups,
)
else:
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
return transII_addbranch(
(k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
(b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged),
)
def re_parameterize(self):
if self.is_repped:
return
kernel, bias = self.get_equivalent_kernel_bias()
self.dbb_reparam = nn.Conv2d(
            in_channels=self.dbb_origin.conv.in_channels,
            out_channels=self.dbb_origin.conv.out_channels,
            kernel_size=self.dbb_origin.conv.kernel_size,
            stride=self.dbb_origin.conv.stride,
            padding=self.dbb_origin.conv.padding,
            dilation=self.dbb_origin.conv.dilation,
            groups=self.dbb_origin.conv.groups,
bias=True,
)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
self.__delattr__("dbb_origin")
self.__delattr__("dbb_avg")
if hasattr(self, "dbb_1x1"):
self.__delattr__("dbb_1x1")
self.__delattr__("dbb_1x1_kxk")
self.is_repped = True
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, inputs):
return inputs
class TheseusLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.res_dict = {}
# self.res_name = self.full_name()
self.res_name = self.__class__.__name__.lower()
self.pruner = None
self.quanter = None
self.init_net(*args, **kwargs)
def _return_dict_hook(self, layer, input, output):
res_dict = {"logits": output}
# 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict):
# clear the res_dict because the forward process may change according to input
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
def init_net(
self,
stages_pattern=None,
return_patterns=None,
return_stages=None,
freeze_befor=None,
stop_after=None,
*args,
**kwargs,
):
# init the output of net
if return_patterns or return_stages:
if return_patterns and return_stages:
msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
return_stages = [
val
for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
# call update_res function after the __init__ of the object has completed execution, that is, the constructing of layer or model has been completed.
def update_res_hook(layer, input):
self.update_res(return_patterns)
self.register_forward_pre_hook(update_res_hook)
# freeze subnet
if freeze_befor is not None:
self.freeze_befor(freeze_befor)
# set subnet to Identity
if stop_after is not None:
self.stop_after(stop_after)
def init_res(self, stages_pattern, return_patterns=None, return_stages=None):
if return_patterns and return_stages:
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
return_stages = [
val
for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
self.update_res(return_patterns)
def replace_sub(self, *args, **kwargs) -> None:
msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
raise DeprecationWarning(msg)
def upgrade_sublayer(
self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Module, str], nn.Module],
) -> Dict[str, nn.Module]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Module, str], nn.Module]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Module) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
Returns:
Dict[str, nn.Module]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
Examples:
from paddle import nn
import paddleclas
def rep_func(layer: nn.Module, pattern: str):
new_layer = nn.Conv2d(
in_channels=layer._in_channels,
out_channels=layer._out_channels,
kernel_size=5,
padding=2
)
return new_layer
net = paddleclas.MobileNetV1()
res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
hit_layer_pattern_list = []
for pattern in layer_name_pattern:
# parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list:
continue
sub_layer_parent = layer_list[-2]["layer"] if len(layer_list) > 1 else self
sub_layer = layer_list[-1]["layer"]
sub_layer_name = layer_list[-1]["name"]
sub_layer_index_list = layer_list[-1]["index_list"]
new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index_list:
if len(sub_layer_index_list) > 1:
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]
]
for sub_layer_index in sub_layer_index_list[1:-1]:
sub_layer_parent = sub_layer_parent[sub_layer_index]
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
else:
getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]
] = new_sub_layer
else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
hit_layer_pattern_list.append(pattern)
return hit_layer_pattern_list
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
Args:
stop_layer_name (str): The name of layer that stop forward and backward after this layer.
Returns:
bool: 'True' if successful, 'False' otherwise.
"""
layer_list = parse_pattern_str(stop_layer_name, self)
if not layer_list:
return False
parent_layer = self
for layer_dict in layer_list:
name, index_list = layer_dict["name"], layer_dict["index_list"]
if not set_identity(parent_layer, name, index_list):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
return False
parent_layer = layer_dict["layer"]
return True
def freeze_befor(self, layer_name: str) -> bool:
"""freeze the layer named layer_name and its previous layer.
Args:
layer_name (str): The name of layer that would be freezed.
Returns:
bool: 'True' if successful, 'False' otherwise.
"""
def stop_grad(layer, pattern):
class StopGradLayer(nn.Module):
def __init__(self):
super().__init__()
self.layer = layer
def forward(self, x):
x = self.layer(x)
                    x = x.detach()  # torch equivalent of Paddle's stop_gradient
return x
new_layer = StopGradLayer()
return new_layer
res = self.upgrade_sublayer(layer_name, stop_grad)
if len(res) == 0:
msg = "Failed to stop the gradient before the layer named '{layer_name}'"
return False
return True
def update_res(self, return_patterns: Union[str, List[str]]) -> Dict[str, nn.Module]:
"""update the result(s) to be returned.
Args:
return_patterns (Union[str, List[str]]): The name of layer to return output.
Returns:
Dict[str, nn.Module]: The pattern(str) and corresponding layer(nn.Module) that have been set successfully.
"""
# clear res_dict that could have been set
self.res_dict = {}
class Handler(object):
def __init__(self, res_dict):
# res_dict is a reference
self.res_dict = res_dict
def __call__(self, layer, pattern):
layer.res_dict = self.res_dict
layer.res_name = pattern
if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove()
                layer.hook_remove_helper = layer.register_forward_hook(
save_sub_res_hook
)
return layer
handle_func = Handler(self.res_dict)
hit_layer_pattern_list = self.upgrade_sublayer(
return_patterns, handle_func=handle_func
)
if hasattr(self, "hook_remove_helper"):
self.hook_remove_helper.remove()
        self.hook_remove_helper = self.register_forward_hook(
self._return_dict_hook
)
return hit_layer_pattern_list
def save_sub_res_hook(layer, input, output):
layer.res_dict[layer.res_name] = output
def set_identity(
parent_layer: nn.Module, layer_name: str, layer_index_list: str = None
) -> bool:
"""set the layer specified by layer_name and layer_index_list to Identity.
Args:
parent_layer (nn.Module): The parent layer of the target layer specified by layer_name and layer_index_list.
layer_name (str): The name of the target layer to be set to Identity.
layer_index_list (list, optional): The indices of the target layer inside parent_layer. Defaults to None.
Returns:
bool: True if successful, False otherwise.
"""
stop_after = False
for sub_layer_name in parent_layer._modules:
if stop_after:
parent_layer._modules[sub_layer_name] = nn.Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index_list and stop_after:
layer_container = parent_layer._modules[layer_name]
for num, layer_index in enumerate(layer_index_list):
stop_after = False
for i in range(num):
layer_container = layer_container[layer_index_list[i]]
for sub_layer_index in layer_container._modules:
if stop_after:
parent_layer._modules[layer_name][sub_layer_index] = nn.Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
def parse_pattern_str(
pattern: str, parent_layer: nn.Module
) -> Union[None, List[Dict[str, Union[nn.Module, str, None]]]]:
"""parse the string type pattern.
Args:
pattern (str): The pattern to describe layer.
parent_layer (nn.Module): The root layer relative to the pattern.
Returns:
Union[None, List[Dict[str, Union[nn.Module, str, None]]]]: None on failure. On success, the layers parsed in order:
[
{"layer": first layer, "name": first layer's parsed name, "index_list": first layer's parsed indices if any},
{"layer": second layer, "name": second layer's parsed name, "index_list": second layer's parsed indices if any},
...
]
"""
pattern_list = pattern.split(".")
if not pattern_list:
msg = f"The pattern('{pattern}') is illegal. Please check and retry."
return None
layer_list = []
while len(pattern_list) > 0:
if "[" in pattern_list[0]:
target_layer_name = pattern_list[0].split("[")[0]
target_layer_index_list = list(
index.split("]")[0] for index in pattern_list[0].split("[")[1:]
)
else:
target_layer_name = pattern_list[0]
target_layer_index_list = None
target_layer = getattr(parent_layer, target_layer_name, None)
if target_layer is None:
msg = f"Not found layer named('{target_layer_name}') specified in pattern('{pattern}')."
return None
if target_layer_index_list:
for target_layer_index in target_layer_index_list:
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
target_layer
):
msg = f"Not found layer by index('{target_layer_index}') specified in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
return None
target_layer = target_layer[target_layer_index]
layer_list.append(
{
"layer": target_layer,
"name": target_layer_name,
"index_list": target_layer_index_list,
}
)
pattern_list = pattern_list[1:]
parent_layer = target_layer
return layer_list
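# Sketch (illustrative pattern): parse_pattern_str("stages[1].blocks[0]", model) would return
#   [
#       {"layer": model.stages[1], "name": "stages", "index_list": ["1"]},
#       {"layer": model.stages[1].blocks[0], "name": "blocks", "index_list": ["0"]},
#   ]
# i.e. one entry per dotted segment, with bracket indices kept as strings.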
class LearnableAffineBlock(TheseusLayer):
"""
Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
......@@ -45,14 +821,41 @@ class LearnableAffineBlock(nn.Module):
def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
super().__init__()
self.scale = nn.Parameter(torch.Tensor([scale_value]))
self.bias = nn.Parameter(torch.Tensor([bias_value]))
# self.scale = self.create_parameter(
# shape=[
# 1,
# ],
# default_initializer=nn.init.Constant(value=scale_value),
# # attr=ParamAttr(learning_rate=lr_mult * lab_lr),
# )
# self.add_parameter("scale", self.scale)
self.scale = nn.Parameter(
nn.init.constant_(
torch.ones(1).to(torch.float32), val=scale_value
)
)
self.register_parameter("scale", self.scale)
# self.bias = self.create_parameter(
# shape=[
# 1,
# ],
# default_initializer=nn.init.Constant(value=bias_value),
# # attr=ParamAttr(learning_rate=lr_mult * lab_lr),
# )
# self.add_parameter("bias", self.bias)
self.bias = nn.Parameter(
nn.init.constant_(
torch.ones(1).to(torch.float32), val=bias_value
)
)
self.register_parameter("bias", self.bias)
def forward(self, x):
return self.scale * x + self.bias
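# Worked example (values are illustrative): with the default scale_value=1.0 and bias_value=0.0
# the block is an identity mapping; after training it applies a learned, channel-shared affine
# transform, e.g. scale=1.2 and bias=-0.05 gives y = 1.2 * x - 0.05.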
class ConvBNAct(nn.Module):
class ConvBNAct(TheseusLayer):
"""
ConvBNAct is a combination of convolution and batchnorm layers.
......@@ -83,14 +886,12 @@ class ConvBNAct(nn.Module):
super().__init__()
self.use_act = use_act
self.use_lab = use_lab
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
# padding=(kernel_size - 1) // 2,
groups=groups,
bias=False,
)
......@@ -112,7 +913,7 @@ class ConvBNAct(nn.Module):
return x
class LightConvBNAct(nn.Module):
class LightConvBNAct(TheseusLayer):
"""
LightConvBNAct is a combination of pw and dw layers.
......@@ -158,84 +959,24 @@ class LightConvBNAct(nn.Module):
return x
class CustomMaxPool2d(nn.Module):
def __init__(
self,
kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False,
data_format="NCHW",
):
super(CustomMaxPool2d, self).__init__()
self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
self.stride = stride if stride is not None else self.kernel_size
self.stride = self.stride if isinstance(self.stride, (tuple, list)) else (self.stride, self.stride)
self.dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
self.return_indices = return_indices
self.ceil_mode = ceil_mode
self.padding_mode = padding
# Use the standard MaxPool2d when padding is not "same"
if padding != "same":
self.padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
self.pool = nn.MaxPool2d(
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
return_indices=self.return_indices,
ceil_mode=self.ceil_mode
)
class PaddingSameAsPaddleMaxPool2d(torch.nn.Module):
def __init__(self, kernel_size, stride=1):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.pool = torch.nn.MaxPool2d(kernel_size, stride, padding=0, ceil_mode=True)
def forward(self, x):
# Handle "same" padding
if self.padding_mode == "same":
input_height, input_width = x.size(2), x.size(3)
# Compute the expected output size
out_height = math.ceil(input_height / self.stride[0])
out_width = math.ceil(input_width / self.stride[1])
# Compute the required padding
pad_height = max((out_height - 1) * self.stride[0] + self.kernel_size[0] - input_height, 0)
pad_width = max((out_width - 1) * self.stride[1] + self.kernel_size[1] - input_width, 0)
# Distribute the padding to both sides
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
# Apply the padding
x = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom))
# Use the standard max_pool2d function
if self.return_indices:
return F.max_pool2d_with_indices(
x,
kernel_size=self.kernel_size,
stride=self.stride,
padding=0,  # already padded manually
dilation=self.dilation,
ceil_mode=self.ceil_mode
)
else:
return F.max_pool2d(
x,
kernel_size=self.kernel_size,
stride=self.stride,
padding=0,  # already padded manually
dilation=self.dilation,
ceil_mode=self.ceil_mode
)
else:
# Use the predefined MaxPool2d
return self.pool(x)
_, _, h, w = x.shape
pad_h_total = max(0, (math.ceil(h / self.stride) - 1) * self.stride + self.kernel_size - h)
pad_w_total = max(0, (math.ceil(w / self.stride) - 1) * self.stride + self.kernel_size - w)
pad_h = pad_h_total // 2
pad_w = pad_w_total // 2
x = torch.nn.functional.pad(x, [pad_w, pad_w_total - pad_w, pad_h, pad_h_total - pad_h])
return self.pool(x)
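# Worked example of the "same"-style padding above (kernel_size=2, stride=1, h=w=5):
#   pad_h_total = max(0, (ceil(5/1) - 1) * 1 + 2 - 5) = 1 -> pad 0 rows on top, 1 on the bottom,
# so the pooled output keeps the 5x5 spatial size, which is presumably the Paddle "SAME"
# pooling behaviour the class name refers to.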
class StemBlock(nn.Module):
class StemBlock(TheseusLayer):
"""
StemBlock for PP-HGNetV2.
......@@ -299,22 +1040,15 @@ class StemBlock(nn.Module):
use_lab=use_lab,
lr_mult=lr_mult,
)
self.pool = CustomMaxPool2d(
kernel_size=2, stride=1, ceil_mode=True, padding="same"
self.pool = PaddingSameAsPaddleMaxPool2d(
kernel_size=2, stride=1,
)
# self.pool = nn.MaxPool2d(
# kernel_size=2, stride=1, ceil_mode=True, padding=1
# )
def forward(self, x):
x = self.stem1(x)
x2 = self.stem2a(x)
x2 = self.stem2b(x2)
x1 = self.pool(x)
# if x1.shape[2:] != x2.shape[2:]:
# x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x1, x2], 1)
x = self.stem3(x)
x = self.stem4(x)
......@@ -322,7 +1056,7 @@ class StemBlock(nn.Module):
return x
class HGV2_Block(nn.Module):
class HGV2_Block(TheseusLayer):
"""
HGV2_Block, the basic unit that constitutes the HGV2_Stage.
......@@ -402,7 +1136,7 @@ class HGV2_Block(nn.Module):
return x
class HGV2_Stage(nn.Module):
class HGV2_Stage(TheseusLayer):
"""
HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
......@@ -472,26 +1206,7 @@ class HGV2_Stage(nn.Module):
return x
class DropoutInferDownscale(nn.Module):
"""
Dropout equivalent to Paddle's mode="downscale_in_infer":
training: out = input * mask (the mask is applied directly, without rescaling)
inference: out = input * (1.0 - p) (the output is scaled down by the keep probability)
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, x):
if self.training:
# Training: apply a random mask without rescaling
return F.dropout(x, self.p, training=True) * (1.0 - self.p)
else:
# Inference: scale the output down by (1 - p)
return x * (1.0 - self.p)
class PPHGNetV2(nn.Module):
class PPHGNetV2(TheseusLayer):
"""
PPHGNetV2
......@@ -505,7 +1220,7 @@ class PPHGNetV2(nn.Module):
class_num (int): The number of classes for the classification layer. Defaults to 1000.
lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
Returns:
model: nn.Layer. Specific PPHGNetV2 model depends on args.
model: nn.Module. Specific PPHGNetV2 model depends on args.
"""
def __init__(
......@@ -577,7 +1292,7 @@ class PPHGNetV2(nn.Module):
if not self.det:
self.out_channels = stage_config["stage4"][2]
self.avg_pool = AdaptiveAvgPool2D(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if self.use_last_conv:
self.last_conv = nn.Conv2d(
......@@ -591,7 +1306,8 @@ class PPHGNetV2(nn.Module):
self.act = nn.ReLU()
if self.use_lab:
self.lab = LearnableAffineBlock()
self.dropout = DropoutInferDownscale(p=dropout_prob)
# self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
self.dropout = nn.Dropout(p=dropout_prob)
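# Note: torch's nn.Dropout uses inverted dropout (scales by 1/(1-p) at train time, identity at
# eval time), whereas Paddle's mode="downscale_in_infer" leaves training activations unscaled
# and multiplies by (1-p) at inference, so eval outputs differ by a constant factor when p > 0.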
self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
if not self.det:
......@@ -606,7 +1322,7 @@ class PPHGNetV2(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, (nn.BatchNorm2d)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
......@@ -638,7 +1354,7 @@ def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B0` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B0` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -662,7 +1378,7 @@ def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B1` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B1` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -686,7 +1402,7 @@ def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B2` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B2` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -710,7 +1426,7 @@ def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B3` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B3` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -734,7 +1450,7 @@ def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B4` model depends on args.
"""
stage_config_rec = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
......@@ -770,7 +1486,7 @@ def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B5` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B5` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -794,7 +1510,7 @@ def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
If str, means the path of the pretrained model.
use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
Returns:
model: nn.Layer. Specific `PPHGNetV2_B6` model depends on args.
model: nn.Module. Specific `PPHGNetV2_B6` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
......@@ -808,3 +1524,119 @@ def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
)
return model
class PPHGNetV2_B4_Formula(nn.Module):
"""
PPHGNetV2_B4_Formula
Args:
in_channels (int): Number of input channels. Default is 3 (for RGB images).
class_num (int): Number of classes for classification. Default is 1000.
Returns:
model: nn.Module. Specific `PPHGNetV2_B4` model with defined architecture.
"""
def __init__(self, in_channels=3, class_num=1000):
super().__init__()
self.in_channels = in_channels
self.out_channels = 2048
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
}
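# Reading one row (assuming the last value is the stride, as in the rec config above):
# "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2] means 512 input channels, 192 mid channels,
# 1024 output channels, 3 HGV2 blocks, downsampling enabled, light blocks, 5x5 kernels,
# 6 layers per block, and stride 2.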
self.pphgnet_b4 = PPHGNetV2(
stem_channels=[3, 32, 48],
stage_config=stage_config,
class_num=class_num,
use_lab=False,
)
def forward(self, input_data):
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
pphgnet_b4_output = self.pphgnet_b4(pixel_values)
b, c, h, w = pphgnet_b4_output.shape
# torch.Tensor.transpose expects two dims, so use permute for the (0, 2, 1) reordering
pphgnet_b4_output = pphgnet_b4_output.reshape(b, c, h * w).permute(0, 2, 1)
pphgnet_b4_output = DonutSwinModelOutput(
last_hidden_state=pphgnet_b4_output,
pooler_output=None,
hidden_states=None,
attentions=False,
reshaped_hidden_states=None,
)
if self.training:
return pphgnet_b4_output, label, attention_mask
else:
return pphgnet_b4_output
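# Shape sketch (sizes are illustrative): a (B, 3, H, W) pixel batch passes through pphgnet_b4
# to a (B, 2048, h, w) feature map, which is flattened and permuted to (B, h*w, 2048) and wrapped
# as DonutSwinModelOutput.last_hidden_state, so a downstream decoder can presumably consume it
# like a Swin encoder output.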
class PPHGNetV2_B6_Formula(nn.Module):
"""
PPHGNetV2_B6_Formula
Args:
in_channels (int): Number of input channels. Default is 3 (for RGB images).
class_num (int): Number of classes for classification. Default is 1000.
Returns:
model: nn.Module. Specific `PPHGNetV2_B6` model with defined architecture.
"""
def __init__(self, in_channels=3, class_num=1000):
super().__init__()
self.in_channels = in_channels
self.out_channels = 2048
stage_config = {
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
"stage1": [96, 96, 192, 2, False, False, 3, 6, 2],
"stage2": [192, 192, 512, 3, True, False, 3, 6, 2],
"stage3": [512, 384, 1024, 6, True, True, 5, 6, 2],
"stage4": [1024, 768, 2048, 3, True, True, 5, 6, 2],
}
self.pphgnet_b6 = PPHGNetV2(
stem_channels=[3, 48, 96],
class_num=class_num,
stage_config=stage_config,
use_lab=False,
)
def forward(self, input_data):
if self.training:
pixel_values, label, attention_mask = input_data
else:
if isinstance(input_data, list):
pixel_values = input_data[0]
else:
pixel_values = input_data
num_channels = pixel_values.shape[1]
if num_channels == 1:
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
pphgnet_b6_output = self.pphgnet_b6(pixel_values)
b, c, h, w = pphgnet_b6_output.shape
pphgnet_b6_output = pphgnet_b6_output.reshape(b, c, h * w).permute(0, 2, 1)
pphgnet_b6_output = DonutSwinModelOutput(
last_hidden_state=pphgnet_b6_output,
pooler_output=None,
hidden_states=None,
attentions=False,
reshaped_hidden_states=None,
)
if self.training:
return pphgnet_b6_output, label, attention_mask
else:
return pphgnet_b6_output