Commit a7a899f6 authored by myhloli's avatar myhloli
Browse files

feat(model): add OCR model base structure and utilities

- Add base model structure for OCR in pytorch
- Implement data augmentation and transformation modules
- Create utilities for dictionary handling and state dict conversion
- Include post-processing modules for OCR
- Add weight initialization and loading functions
parent 72e66c2d
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.act = act
if self.act is not None:
self._act = Activation(act_type=self.act, inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.act is not None:
x = self._act(x)
return x
class PGHead(nn.Module):
"""
"""
def __init__(self, in_channels, **kwargs):
super(PGHead, self).__init__()
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(1))
self.conv_f_score2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_score{}".format(2))
self.conv_f_score3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(3))
self.conv1 = nn.Conv2d(
in_channels=128,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
groups=1,
bias=False)
self.conv_f_boder1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(1))
self.conv_f_boder2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_boder{}".format(2))
self.conv_f_boder3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(3))
self.conv2 = nn.Conv2d(
in_channels=128,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
groups=1,
bias=False)
self.conv_f_char1 = ConvBNLayer(
in_channels=in_channels,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(1))
self.conv_f_char2 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(2))
self.conv_f_char3 = ConvBNLayer(
in_channels=128,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(3))
self.conv_f_char4 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(4))
self.conv_f_char5 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2d(
in_channels=256,
out_channels=37,
kernel_size=3,
stride=1,
padding=1,
groups=1,
bias=False)
self.conv_f_direc1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(1))
self.conv_f_direc2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_direc{}".format(2))
self.conv_f_direc3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(3))
self.conv4 = nn.Conv2d(
in_channels=128,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
groups=1,
bias=False)
def forward(self, x):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
f_score = self.conv1(f_score)
f_score = torch.sigmoid(f_score)
# f_border
f_border = self.conv_f_boder1(x)
f_border = self.conv_f_boder2(f_border)
f_border = self.conv_f_boder3(f_border)
f_border = self.conv2(f_border)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
f_char = self.conv_f_char3(f_char)
f_char = self.conv_f_char4(f_char)
f_char = self.conv_f_char5(f_char)
f_char = self.conv3(f_char)
f_direction = self.conv_f_direc1(x)
f_direction = self.conv_f_direc2(f_direction)
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
predicts = {}
predicts['f_score'] = f_score
predicts['f_border'] = f_border
predicts['f_char'] = f_char
predicts['f_direction'] = f_direction
return predicts
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
class MultiheadAttention(nn.Module):
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
"""
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
bias=True,
add_bias_kv=False,
add_zero_attn=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
self._reset_parameters()
self.conv1 = torch.nn.Conv2d(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = torch.nn.Conv2d(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = torch.nn.Conv2d(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
def forward(self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
attn_mask=None):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
q_shape = query.shape
src_shape = key.shape
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
# q = paddle.transpose(
# paddle.reshape(
# q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
# [1, 2, 0, 3])
q = torch.reshape(q, (q_shape[0], q_shape[1], self.num_heads, self.head_dim))
q = q.permute(1, 2, 0, 3)
# k = paddle.transpose(
# paddle.reshape(
# k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
# [1, 2, 0, 3])
k = torch.reshape(k, (src_shape[0], q_shape[1], self.num_heads, self.head_dim))
k = k.permute(1, 2, 0, 3)
# v = paddle.transpose(
# paddle.reshape(
# v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
# [1, 2, 0, 3])
v = torch.reshape(v, (src_shape[0], q_shape[1], self.num_heads, self.head_dim))
v = v.permute(1, 2, 0, 3)
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == q_shape[1]
assert key_padding_mask.shape[1] == src_shape[0]
attn_output_weights = torch.matmul(q,
k.permute(0, 1, 3, 2))
if attn_mask is not None:
attn_mask = torch.unsqueeze(torch.unsqueeze(attn_mask, 0), 0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = torch.reshape(
attn_output_weights,
[q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
key = torch.unsqueeze(torch.unsqueeze(key_padding_mask, 1), 2)
key = key.type(torch.float32)
y = torch.full(
size=key.shape, fill_value=float("-Inf"), dtype=torch.float32)
y = torch.where(key == 0., key, y)
attn_output_weights += y
attn_output_weights = F.softmax(
attn_output_weights.type(torch.float32),
dim=-1,
dtype=torch.float32 if attn_output_weights.dtype == torch.float16
else attn_output_weights.dtype)
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_output_weights, v)
attn_output = torch.reshape(
attn_output.permute(2, 0, 1, 3),
[q_shape[0], q_shape[1], self.embed_dim])
attn_output = self.out_proj(attn_output)
return attn_output
def _in_proj_q(self, query):
query = query.permute(1, 2, 0)
query = torch.unsqueeze(query, dim=2)
res = self.conv1(query)
res = torch.squeeze(res, dim=2)
res = res.permute(2, 0, 1)
return res
def _in_proj_k(self, key):
key = key.permute(1, 2, 0)
key = torch.unsqueeze(key, dim=2)
res = self.conv2(key)
res = torch.squeeze(res, dim=2)
res = res.permute(2, 0, 1)
return res
def _in_proj_v(self, value):
value = value.permute(1, 2, 0) #(1, 2, 0)
value = torch.unsqueeze(value, dim=2)
res = self.conv3(value)
res = torch.squeeze(res, dim=2)
res = res.permute(2, 0, 1)
return res
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
class AttentionHead(nn.Module):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionHead, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionGRUCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char.type(torch.int64), onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.size()[0]
num_steps = batch_max_length
hidden = torch.zeros((batch_size, self.hidden_size))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots)
output_hiddens.append(torch.unsqueeze(outputs, dim=1))
output = torch.cat(output_hiddens, dim=1)
probs = self.generator(output)
else:
targets = torch.zeros([batch_size], dtype=torch.int32)
probs = None
char_onehots = None
outputs = None
alpha = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(outputs)
if probs is None:
probs = torch.unsqueeze(probs_step, dim=1)
else:
probs = torch.cat(
[probs, torch.unsqueeze(
probs_step, dim=1)], dim=1)
next_input = probs_step.argmax(dim=1)
targets = next_input
return probs
class AttentionGRUCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size, bias=True)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)
res = torch.add(batch_H_proj, prev_hidden_proj)
res = torch.tanh(res)
e = self.score(res)
alpha = F.softmax(e, dim=1)
alpha = alpha.permute(0, 2, 1)
context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
concat_context = torch.cat([context, char_onehots.float()], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return (cur_hidden, cur_hidden), alpha
class AttentionLSTM(nn.Module):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char.type(torch.int64), onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (torch.zeros((batch_size, self.hidden_size)), torch.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(torch.unsqueeze(hidden[0], dim=1))
output = torch.cat(output_hiddens, dim=1)
probs = self.generator(output)
else:
targets = torch.zeros([batch_size], dtype=torch.int32)
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = torch.unsqueeze(probs_step, dim=1)
else:
probs = torch.cat(
[probs, torch.unsqueeze(
probs_step, dim=1)], dim=1)
next_input = probs_step.argmax(dim=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden[0]), dim=1)
res = torch.add(batch_H_proj, prev_hidden_proj)
res = torch.tanh(res)
e = self.score(res)
alpha = F.softmax(e, dim=1)
alpha = alpha.permute(0, 2, 1)
context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
concat_context = torch.cat([context, char_onehots.float()], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
"""
This code is refer from:
https://github.com/LBH1024/CAN/models/can.py
https://github.com/LBH1024/CAN/models/counting.py
https://github.com/LBH1024/CAN/models/decoder.py
https://github.com/LBH1024/CAN/models/attention.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
import torch
import math
'''
Counting Module
'''
class ChannelAtt(nn.Module):
def __init__(self, channel, reduction):
super(ChannelAtt, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.shape
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class CountingDecoder(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size):
super(CountingDecoder, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.trans_layer = nn.Sequential(
nn.Conv2d(
self.in_channel,
512,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=False),
nn.BatchNorm2d(512))
self.channel_att = ChannelAtt(512, 16)
self.pred_layer = nn.Sequential(
nn.Conv2d(
512, self.out_channel, kernel_size=1, bias=False),
nn.Sigmoid())
def forward(self, x, mask):
b, _, h, w = x.shape
x = self.trans_layer(x)
x = self.channel_att(x)
x = self.pred_layer(x)
if mask is not None:
x = x * mask
x = x.view(b, self.out_channel, -1)
x1 = torch.sum(x, dim=-1)
return x1, x.view(b, self.out_channel, h, w)
'''
Attention Decoder
'''
class PositionEmbeddingSine(nn.Module):
def __init__(self,
num_pos_feats=64,
temperature=10000,
normalize=False,
scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask):
y_embed = mask.cumsum(1, dtype=torch.float32)
x_embed = mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
# dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
# dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
# self.num_pos_feats)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = torch.unsqueeze(x_embed, 3) / dim_t
pos_y = torch.unsqueeze(y_embed, 3) / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class AttDecoder(nn.Module):
def __init__(self, ratio, is_train, input_size, hidden_size,
encoder_out_channel, dropout, dropout_ratio, word_num,
counting_decoder_out_channel, attention):
super(AttDecoder, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.out_channel = encoder_out_channel
self.attention_dim = attention['attention_dim']
self.dropout_prob = dropout
self.ratio = ratio
self.word_num = word_num
self.counting_num = counting_decoder_out_channel
self.is_train = is_train
self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
self.embedding = nn.Embedding(self.word_num, self.input_size)
self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
self.word_attention = Attention(hidden_size, attention['attention_dim'])
self.encoder_feature_conv = nn.Conv2d(
self.out_channel,
self.attention_dim,
kernel_size=attention['word_conv_kernel'],
padding=attention['word_conv_kernel'] // 2)
self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
self.word_embedding_weight = nn.Linear(self.input_size,
self.hidden_size)
self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
self.counting_context_weight = nn.Linear(self.counting_num,
self.hidden_size)
self.word_convert = nn.Linear(self.hidden_size, self.word_num)
if dropout:
self.dropout = nn.Dropout(dropout_ratio)
def forward(self, cnn_features, labels, counting_preds, images_mask):
if self.is_train:
_, num_steps = labels.shape
else:
num_steps = 36
batch_size, _, height, width = cnn_features.shape
images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
word_probs = torch.zeros((batch_size, num_steps, self.word_num)).to(device=cnn_features.device)
word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=cnn_features.device)
hidden = self.init_hidden(cnn_features, images_mask)
counting_context_weighted = self.counting_context_weight(counting_preds)
cnn_features_trans = self.encoder_feature_conv(cnn_features)
position_embedding = PositionEmbeddingSine(256, normalize=True)
pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
cnn_features_trans = cnn_features_trans + pos
word = torch.ones([batch_size]).long().to(device=cnn_features.device) # init word as sos
for i in range(num_steps):
word_embedding = self.embedding(word)
hidden = self.word_input_gru(word_embedding, hidden)
word_context_vec, _, word_alpha_sum = self.word_attention(
cnn_features, cnn_features_trans, hidden, word_alpha_sum,
images_mask)
current_state = self.word_state_weight(hidden)
word_weighted_embedding = self.word_embedding_weight(word_embedding)
word_context_weighted = self.word_context_weight(word_context_vec)
if self.dropout_prob:
word_out_state = self.dropout(
current_state + word_weighted_embedding +
word_context_weighted + counting_context_weighted)
else:
word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
word_prob = self.word_convert(word_out_state)
word_probs[:, i] = word_prob
if self.is_train:
word = labels[:, i]
else:
word = word_prob.argmax(1)
word = torch.mul(
word, labels[:, i]
) # labels are oneslike tensor in infer/predict mode, torch.multiply
return word_probs
def init_hidden(self, features, feature_mask):
average = torch.sum(torch.sum(features * feature_mask, dim=-1),
dim=-1) / torch.sum(
(torch.sum(feature_mask, dim=-1)), dim=-1)
average = self.init_weight(average)
return torch.tanh(average)
'''
Attention Module
'''
class Attention(nn.Module):
def __init__(self, hidden_size, attention_dim):
super(Attention, self).__init__()
self.hidden = hidden_size
self.attention_dim = attention_dim
self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
self.attention_conv = nn.Conv2d(
1, 512, kernel_size=11, padding=5, bias=False)
self.attention_weight = nn.Linear(
512, self.attention_dim, bias=False)
self.alpha_convert = nn.Linear(self.attention_dim, 1)
def forward(self,
cnn_features,
cnn_features_trans,
hidden,
alpha_sum,
image_mask=None):
query = self.hidden_weight(hidden)
alpha_sum_trans = self.attention_conv(alpha_sum)
coverage_alpha = self.attention_weight(alpha_sum_trans.permute(0, 2, 3, 1))
alpha_score = torch.tanh(
query[:, None, None, :] + coverage_alpha + cnn_features_trans.permute(0, 2, 3, 1)
)
energy = self.alpha_convert(alpha_score)
energy = energy - energy.max()
energy_exp = torch.exp(torch.squeeze(energy, -1))
if image_mask is not None:
energy_exp = energy_exp * torch.squeeze(image_mask, 1)
alpha = energy_exp / (energy_exp.sum(-1).sum(-1)[:,None,None] + 1e-10)
alpha_sum = torch.unsqueeze(alpha, 1) + alpha_sum
context_vector = torch.sum(
torch.sum((torch.unsqueeze(alpha, 1) * cnn_features), -1), -1)
return context_vector, alpha, alpha_sum
class CANHead(nn.Module):
def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
super(CANHead, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.counting_decoder1 = CountingDecoder(self.in_channel,
self.out_channel, 3) # mscm
self.counting_decoder2 = CountingDecoder(self.in_channel,
self.out_channel, 5)
self.decoder = AttDecoder(ratio, **attdecoder)
self.ratio = ratio
def forward(self, inputs, targets=None):
cnn_features, images_mask, labels = inputs
counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
counting_preds = (counting_preds1 + counting_preds2) / 2
word_probs = self.decoder(cnn_features, labels, counting_preds,
images_mask)
return word_probs, counting_preds, counting_preds1, counting_preds2
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
class CTCHead(nn.Module):
def __init__(self,
in_channels,
out_channels=6625,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = (x, predicts)
else:
result = predicts
if not self.training:
predicts = F.softmax(predicts, dim=2)
result = predicts
return result
\ No newline at end of file
import torch
import torch.nn as nn
from pytorchocr.modeling.necks.rnn import Im2Seq, SequenceEncoder
from .rec_nrtr_head import Transformer
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead
class FCTranspose(nn.Module):
def __init__(self, in_channels, out_channels, only_transpose=False):
super().__init__()
self.only_transpose = only_transpose
if not self.only_transpose:
self.fc = nn.Linear(in_channels, out_channels, bias=False)
def forward(self, x):
if self.only_transpose:
return x.permute([0, 2, 1])
else:
return self.fc(x.permute([0, 2, 1]))
class MultiHead(nn.Module):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop('head_list')
self.gtc_head = 'sar'
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
name = list(head_name)[0]
if name == 'SARHead':
pass
# # sar head
# sar_args = self.head_list[idx][name]
# self.sar_head = eval(name)(in_channels=in_channels, \
# out_channels=out_channels_list['SARLabelDecode'], **sar_args)
elif name == 'NRTRHead':
pass
# gtc_args = self.head_list[idx][name]
# max_text_length = gtc_args.get('max_text_length', 25)
# nrtr_dim = gtc_args.get('nrtr_dim', 256)
# num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
# self.before_gtc = nn.Sequential(
# nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
# self.gtc_head = Transformer(
# d_model=nrtr_dim,
# nhead=nrtr_dim // 32,
# num_encoder_layers=-1,
# beam_size=-1,
# num_decoder_layers=num_decoder_layers,
# max_len=max_text_length,
# dim_feedforward=nrtr_dim * 4,
# out_channels=out_channels_list['NRTRLabelDecode'])
elif name == 'CTCHead':
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]['Neck']
encoder_type = neck_args.pop('name')
self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
encoder_type=encoder_type, **neck_args)
# ctc head
head_args = self.head_list[idx][name].get('Head', {})
if head_args is None:
head_args = {}
self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
out_channels=out_channels_list['CTCLabelDecode'], **head_args)
else:
raise NotImplementedError(
'{} is not supported in MultiHead yet'.format(name))
def forward(self, x, data=None):
ctc_encoder = self.ctc_encoder(x)
ctc_out = self.ctc_head(ctc_encoder)
head_out = dict()
head_out['ctc'] = ctc_out
head_out['res'] = ctc_out
head_out['ctc_neck'] = ctc_encoder
# eval mode
if not self.training:
return ctc_out
if self.gtc_head == 'sar':
sar_out = self.sar_head(x, data[1:])['res']
head_out['sar'] = sar_out
else:
gtc_out = self.gtc_head(self.before_gtc(x), data[1:])['res']
head_out['nrtr'] = gtc_out
return head_out
import math
import torch
import copy
from torch import nn
import torch.nn.functional as F
from torch.nn import ModuleList as LayerList
from torch.nn.init import xavier_uniform_
from torch.nn import Dropout, LayerNorm, Conv2d
import numpy as np
from pytorchocr.modeling.heads.multiheadAttention import MultiheadAttention
from torch.nn.init import xavier_normal_
class Transformer(nn.Module):
"""A transformer model. User is able to modify the attributes as needed. The architechture
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
Processing Systems, pages 6000-6010.
Args:
d_model: the number of expected features in the encoder/decoder inputs (default=512).
nhead: the number of heads in the multiheadattention models (default=8).
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
"""
def __init__(self,
d_model=512,
nhead=8,
num_encoder_layers=6,
beam_size=0,
num_decoder_layers=6,
max_len=25,
dim_feedforward=1024,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1,
custom_encoder=None,
custom_decoder=None,
in_channels=0,
out_channels=0,
scale_embedding=True):
super(Transformer, self).__init__()
self.out_channels = out_channels # out_channels + 1
self.max_len = max_len
self.embedding = Embeddings(
d_model=d_model,
vocab=self.out_channels,
padding_idx=0,
scale_embedding=scale_embedding)
self.positional_encoding = PositionalEncoding(
dropout=residual_dropout_rate,
dim=d_model, )
if custom_encoder is not None:
self.encoder = custom_encoder
else:
if num_encoder_layers > 0:
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, attention_dropout_rate,
residual_dropout_rate)
self.encoder = TransformerEncoder(encoder_layer,
num_encoder_layers)
else:
self.encoder = None
if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, attention_dropout_rate,
residual_dropout_rate)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
self._reset_parameters()
self.beam_size = beam_size
self.d_model = d_model
self.nhead = nhead
self.tgt_word_prj = nn.Linear(
d_model, self.out_channels, bias=False)
w0 = np.random.normal(0.0, d_model ** -0.5,
(self.out_channels, d_model)).astype(np.float32)
self.tgt_word_prj.weight.data = torch.from_numpy(w0)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
xavier_normal_(m.weight)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def forward_train(self, src, tgt):
tgt = tgt[:, :-1]
tgt_key_padding_mask = self.generate_padding_mask(tgt)
tgt = self.embedding(tgt).permute(1, 0, 2)
tgt = self.positional_encoding(tgt)
tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
if self.encoder is not None:
src = self.positional_encoding(src.permute(1, 0, 2))
memory = self.encoder(src)
else:
memory = src.squeeze(2).permute(2, 0, 1)
output = self.decoder(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=None,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=None)
output = output.permute(1, 0, 2)
logit = self.tgt_word_prj(output)
return logit
def forward(self, src, targets=None):
"""Take in and process masked source/target sequences.
Args:
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
Shape:
- src: :math:`(S, N, E)`.
- tgt: :math:`(T, N, E)`.
Examples:
>>> output = transformer_model(src, tgt)
"""
if self.training:
max_len = targets[1].max()
tgt = targets[0][:, :2 + max_len]
return self.forward_train(src, tgt)
else:
if self.beam_size > 0:
return self.forward_beam(src)
else:
return self.forward_test(src)
def forward_test(self, src):
bs = src.shape[0]
if self.encoder is not None:
src = self.positional_encoding(src.permute(1, 0, 2))
memory = self.encoder(src)
else:
memory = torch.squeeze(src, 2).permute(2, 0, 1)
dec_seq = torch.full((bs, 1), 2, dtype=torch.int64)
dec_prob = torch.full((bs, 1), 1., dtype=torch.float32)
for len_dec_seq in range(1, 25):
dec_seq_embed = self.embedding(dec_seq).permute(1, 0, 2)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(
dec_seq_embed.shape[0])
output = self.decoder(
dec_seq_embed,
memory,
tgt_mask=tgt_mask,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None)
dec_output = output.permute(1, 0, 2)
dec_output = dec_output[:, -1, :]
tgt_word_prj = self.tgt_word_prj(dec_output)
word_prob = F.softmax(tgt_word_prj, dim=1)
preds_idx = word_prob.argmax(dim=1)
if torch.equal(
preds_idx,
torch.full(
preds_idx.shape, 3, dtype=torch.int64)):
break
preds_prob = torch.max(word_prob, dim=1).values
dec_seq = torch.cat(
[dec_seq, torch.reshape(preds_idx, (-1, 1))], dim=1)
dec_prob = torch.cat(
[dec_prob, torch.reshape(preds_prob, (-1, 1))], dim=1)
return [dec_seq, dec_prob]
def forward_beam(self, images):
''' Translation work in one batch '''
def get_inst_idx_to_tensor_position_map(inst_idx_list):
''' Indicate the position of an instance in a tensor. '''
return {
inst_idx: tensor_position
for tensor_position, inst_idx in enumerate(inst_idx_list)
}
def collect_active_part(beamed_tensor, curr_active_inst_idx,
n_prev_active_inst, n_bm):
''' Collect tensor parts associated to active instances. '''
beamed_tensor_shape = beamed_tensor.shape
n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
beamed_tensor_shape[2])
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select(
curr_active_inst_idx, axis=0)
beamed_tensor = beamed_tensor.reshape(new_shape)
return beamed_tensor
def collate_active_info(src_enc, inst_idx_to_position_map,
active_inst_idx_list):
# Sentences which are still active are collected,
# so the decoder will not run on completed sentences.
n_prev_active_inst = len(inst_idx_to_position_map)
active_inst_idx = [
inst_idx_to_position_map[k] for k in active_inst_idx_list
]
active_inst_idx = torch.tensor(active_inst_idx, dtype=torch.int64)
active_src_enc = collect_active_part(
src_enc.permute(1, 0, 2), active_inst_idx,
n_prev_active_inst, n_bm).permute(1, 0, 2)
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list)
return active_src_enc, active_inst_idx_to_position_map
def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
inst_idx_to_position_map, n_bm,
memory_key_padding_mask):
''' Decode and update beam status, and then return active beam idx '''
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
dec_partial_seq = [
b.get_current_state() for b in inst_dec_beams if not b.done
]
dec_partial_seq = torch.stack(dec_partial_seq)
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
return dec_partial_seq
def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
memory_key_padding_mask):
dec_seq = self.embedding(dec_seq).permute(1, 0, 2)
dec_seq = self.positional_encoding(dec_seq)
tgt_mask = self.generate_square_subsequent_mask(
dec_seq.shape[0])
dec_output = self.decoder(
dec_seq,
enc_output,
tgt_mask=tgt_mask,
tgt_key_padding_mask=None,
memory_key_padding_mask=memory_key_padding_mask, )
dec_output = dec_output.permute(1, 0, 2)
dec_output = dec_output[:,
-1, :] # Pick the last step: (bh * bm) * d_h
word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=1)
word_prob = torch.reshape(word_prob, (n_active_inst, n_bm, -1))
return word_prob
def collect_active_inst_idx_list(inst_beams, word_prob,
inst_idx_to_position_map):
active_inst_idx_list = []
for inst_idx, inst_position in inst_idx_to_position_map.items():
is_inst_complete = inst_beams[inst_idx].advance(word_prob[
inst_position])
if not is_inst_complete:
active_inst_idx_list += [inst_idx]
return active_inst_idx_list
n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
None)
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list(
inst_dec_beams, word_prob, inst_idx_to_position_map)
return active_inst_idx_list
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
all_hyp, all_scores = [], []
for inst_idx in range(len(inst_dec_beams)):
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
all_scores += [scores[:n_best]]
hyps = [
inst_dec_beams[inst_idx].get_hypothesis(i)
for i in tail_idxs[:n_best]
]
all_hyp += [hyps]
return all_hyp, all_scores
with torch.no_grad():
#-- Encode
if self.encoder is not None:
src = self.positional_encoding(images.permute(1, 0, 2))
src_enc = self.encoder(src)
else:
src_enc = images.squeeze(2).transpose([0, 2, 1])
n_bm = self.beam_size
src_shape = src_enc.shape
inst_dec_beams = [Beam(n_bm) for _ in range(1)]
active_inst_idx_list = list(range(1))
# Repeat data for beam search
# src_enc = paddle.tile(src_enc, [1, n_bm, 1])
src_enc = src_enc.repeat(1, n_bm, 1)
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list)
# Decode
for len_dec_seq in range(1, 25):
src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step(
inst_dec_beams, len_dec_seq, src_enc_copy,
inst_idx_to_position_map, n_bm, None)
if not active_inst_idx_list:
break # all instances have finished their path to <EOS>
src_enc, inst_idx_to_position_map = collate_active_info(
src_enc_copy, inst_idx_to_position_map,
active_inst_idx_list)
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
1)
result_hyp = []
hyp_scores = []
for bs_hyp, score in zip(batch_hyp, batch_scores):
l = len(bs_hyp[0])
bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
result_hyp.append(bs_hyp_pad)
score = float(score) / l
hyp_score = [score for _ in range(25)]
hyp_scores.append(hyp_score)
return [
torch.tensor(
np.array(result_hyp), dtype=torch.int64),
torch.tensor(hyp_scores)
]
def generate_square_subsequent_mask(self, sz):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = torch.zeros([sz, sz], dtype=torch.float32)
mask_inf = torch.triu(
torch.full(
size=[sz, sz], fill_value=float('-Inf'), dtype=torch.float32),
diagonal=1)
mask = mask + mask_inf
return mask
def generate_padding_mask(self, x):
# padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
padding_mask = (x == torch.tensor(0, dtype=x.dtype))
return padding_mask
def _reset_parameters(self):
"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
xavier_uniform_(p)
class TransformerEncoder(nn.Module):
"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
"""
def __init__(self, encoder_layer, num_layers):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
def forward(self, src):
"""Pass the input through the endocder layers in turn.
Args:
src: the sequnce to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
"""
output = src
for i in range(self.num_layers):
output = self.layers[i](output,
src_mask=None,
src_key_padding_mask=None)
return output
class TransformerDecoder(nn.Module):
"""TransformerDecoder is a stack of N decoder layers
Args:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
"""
def __init__(self, decoder_layer, num_layers):
super(TransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None):
"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
memory: the sequnce from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
"""
output = tgt
for i in range(self.num_layers):
output = self.layers[i](
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output
class TransformerEncoderLayer(nn.Module):
"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
"""
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)
self.conv1 = nn.Conv2d(
in_channels=d_model,
out_channels=dim_feedforward,
kernel_size=(1, 1))
self.conv2 = nn.Conv2d(
in_channels=dim_feedforward,
out_channels=d_model,
kernel_size=(1, 1))
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = Dropout(residual_dropout_rate)
self.dropout2 = Dropout(residual_dropout_rate)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""Pass the input through the endocder layer.
Args:
src: the sequnce to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
"""
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
src = src.permute(1, 2, 0)
src = torch.unsqueeze(src, 2)
src2 = self.conv2(F.relu(self.conv1(src)))
src2 = torch.squeeze(src2, 2)
src2 = src2.permute(2, 0, 1)
src = torch.squeeze(src, 2)
src = src.permute(2, 0, 1)
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
class TransformerDecoderLayer(nn.Module):
"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
This standard decoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
"""
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)
self.multihead_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)
self.conv1 = nn.Conv2d(
in_channels=d_model,
out_channels=dim_feedforward,
kernel_size=(1, 1))
self.conv2 = nn.Conv2d(
in_channels=dim_feedforward,
out_channels=d_model,
kernel_size=(1, 1))
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.dropout1 = Dropout(residual_dropout_rate)
self.dropout2 = Dropout(residual_dropout_rate)
self.dropout3 = Dropout(residual_dropout_rate)
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None):
"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequnce from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
"""
tgt2 = self.self_attn(
tgt,
tgt,
tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
tgt,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# default
tgt = tgt.permute(1, 2, 0)
tgt = torch.unsqueeze(tgt, 2)
tgt2 = self.conv2(F.relu(self.conv1(tgt)))
tgt2 = torch.squeeze(tgt2, 2)
tgt2 = tgt2.permute(2, 0, 1)
tgt = torch.squeeze(tgt, 2)
tgt = tgt.permute(2, 0, 1)
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def _get_clones(module, N):
return LayerList([copy.deepcopy(module) for i in range(N)])
class PositionalEncoding(nn.Module):
"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
\text{where pos is the word position and i is the embed idx)
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
>>> pos_encoder = PositionalEncoding(d_model)
"""
def __init__(self, dropout, dim, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros([max_len, dim])
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, dim, 2).type(torch.float32) *
(-math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = torch.unsqueeze(pe, 0)
pe = pe.permute(1, 0, 2)
self.register_buffer('pe', pe)
def forward(self, x):
"""Inputs of forward function
Args:
x: the sequence fed to the positional encoder model (required).
Shape:
x: [sequence length, batch size, embed dim]
output: [sequence length, batch size, embed dim]
Examples:
>>> output = pos_encoder(x)
"""
x = x + self.pe[:x.shape[0], :]
return self.dropout(x)
class PositionalEncoding_2d(nn.Module):
"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
\text{where pos is the word position and i is the embed idx)
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
>>> pos_encoder = PositionalEncoding(d_model)
"""
def __init__(self, dropout, dim, max_len=5000):
super(PositionalEncoding_2d, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros([max_len, dim])
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, dim, 2).type(torch.float32) *
(-math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = torch.unsqueeze(pe, 0).permute(1, 0, 2)
self.register_buffer('pe', pe)
self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
self.linear1 = nn.Linear(dim, dim)
self.linear1.weight.data.fill_(1.)
self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
self.linear2 = nn.Linear(dim, dim)
self.linear2.weight.data.fill_(1.)
def forward(self, x):
"""Inputs of forward function
Args:
x: the sequence fed to the positional encoder model (required).
Shape:
x: [sequence length, batch size, embed dim]
output: [sequence length, batch size, embed dim]
Examples:
>>> output = pos_encoder(x)
"""
w_pe = self.pe[:x.shape[-1], :]
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
w_pe = w_pe * w1
w_pe = w_pe.permute(1, 2, 0)
w_pe = torch.unsqueeze(w_pe, 2)
h_pe = self.pe[:x.shape[-2], :]
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
h_pe = h_pe * w2
h_pe = h_pe.permute(1, 2, 0)
h_pe = torch.unsqueeze(h_pe, 3)
x = x + w_pe + h_pe
x = torch.reshape(
x, [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]
).permute(2,0,1)
return self.dropout(x)
class Embeddings(nn.Module):
def __init__(self, d_model, vocab, padding_idx, scale_embedding):
super(Embeddings, self).__init__()
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
w0 = np.random.normal(0.0, d_model**-0.5,
(vocab, d_model)).astype(np.float32)
self.embedding.weight.data = torch.from_numpy(w0)
self.d_model = d_model
self.scale_embedding = scale_embedding
def forward(self, x):
if self.scale_embedding:
x = self.embedding(x)
return x * math.sqrt(self.d_model)
return self.embedding(x)
class Beam():
''' Beam search '''
def __init__(self, size, device=False):
self.size = size
self._done = False
# The score for each translation on the beam.
self.scores = torch.zeros((size, ), dtype=torch.float32)
self.all_scores = []
# The backpointers at each time-step.
self.prev_ks = []
# The outputs at each time-step.
self.next_ys = [torch.full((size, ), 0, dtype=torch.int64)]
self.next_ys[0][0] = 2
def get_current_state(self):
"Get the outputs for the current timestep."
return self.get_tentative_hypothesis()
def get_current_origin(self):
"Get the backpointers for the current timestep."
return self.prev_ks[-1]
@property
def done(self):
return self._done
def advance(self, word_prob):
"Update beam status and check if finished or not."
num_words = word_prob.shape[1]
# Sum the previous scores.
if len(self.prev_ks) > 0:
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
else:
beam_lk = word_prob[0]
flat_beam_lk = beam_lk.reshape([-1])
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
True) # 1st sort
self.all_scores.append(self.scores)
self.scores = best_scores
# bestScoresId is flattened as a (beam x word) array,
# so we need to calculate which word and beam each score came from
prev_k = best_scores_id // num_words
self.prev_ks.append(prev_k)
self.next_ys.append(best_scores_id - prev_k * num_words)
# End condition is when top-of-beam is EOS.
if self.next_ys[-1][0] == 3:
self._done = True
self.all_scores.append(self.scores)
return self._done
def sort_scores(self):
"Sort the scores."
return self.scores, torch.tensor(
[i for i in range(int(self.scores.shape[0]))], dtype=torch.int32)
def get_the_best_score_and_idx(self):
"Get the score of the best in the beam."
scores, ids = self.sort_scores()
return scores[1], ids[1]
def get_tentative_hypothesis(self):
"Get the decoded sequence for the current timestep."
if len(self.next_ys) == 1:
dec_seq = self.next_ys[0].unsqueeze(1)
else:
_, keys = self.sort_scores()
hyps = [self.get_hypothesis(k) for k in keys]
hyps = [[2] + h for h in hyps]
dec_seq = torch.tensor(hyps, dtype=torch.int64)
return dec_seq
def get_hypothesis(self, k):
""" Walk back to construct the full hypothesis. """
hyp = []
for j in range(len(self.prev_ks) - 1, -1, -1):
hyp.append(self.next_ys[j + 1][k])
k = self.prev_ks[j][k]
return list(map(lambda x: x.item(), hyp[::-1]))
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import paddle
# from paddle import ParamAttr
# import paddle.nn as nn
# import paddle.nn.functional as F
class SAREncoder(nn.Module):
"""
Args:
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
enc_gru (bool): If True, use GRU, else LSTM in encoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
mask (bool): If True, mask padding in RNN sequence.
"""
def __init__(self,
enc_bi_rnn=False,
enc_drop_rnn=0.0,
enc_gru=False,
d_model=512,
d_enc=512,
mask=True,
**kwargs):
super().__init__()
assert isinstance(enc_bi_rnn, bool)
assert isinstance(enc_drop_rnn, (int, float))
assert 0 <= enc_drop_rnn < 1.0
assert isinstance(enc_gru, bool)
assert isinstance(d_model, int)
assert isinstance(d_enc, int)
assert isinstance(mask, bool)
self.enc_bi_rnn = enc_bi_rnn
self.enc_drop_rnn = enc_drop_rnn
self.mask = mask
# LSTM Encoder
# if enc_bi_rnn:
# direction = 'bidirectional'
# else:
# direction = 'forward'
kwargs = dict(
input_size=d_model,
hidden_size=d_enc,
num_layers=2,
batch_first=True,
dropout=enc_drop_rnn,
bidirectional=enc_bi_rnn)
if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs)
else:
self.rnn_encoder = nn.LSTM(**kwargs)
# global feature transformation
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
def forward(self, feat, img_metas=None):
if img_metas is not None:
assert len(img_metas[0]) == feat.size(0)
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
h_feat = feat.shape[2] # bsz c h w
feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
if valid_ratios is not None:
valid_hf = []
T = holistic_feat.size(1)
for i in range(valid_ratios.size(0)):
valid_step = torch.min(T, torch.ceil(T * valid_ratios[i])) - 1
# valid_step = paddle.minimum(
# T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = torch.stack(valid_hf, dim=0)
else:
valid_hf = holistic_feat[:, -1, :] # bsz * C
holistic_feat = self.linear(valid_hf) # bsz * C
return holistic_feat
class BaseDecoder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward_train(self, feat, out_enc, targets, img_metas):
raise NotImplementedError
def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError
def forward(self,
feat,
out_enc,
label=None,
img_metas=None,
train_mode=True):
self.train_mode = train_mode
if train_mode:
return self.forward_train(feat, out_enc, label, img_metas)
return self.forward_test(feat, out_enc, img_metas)
class ParallelSARDecoder(BaseDecoder):
"""
Args:
out_channels (int): Output class number.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_drop_rnn (float): Dropout of RNN layer in decoder.
dec_gru (bool): If True, use GRU, else LSTM in decoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
d_k (int): Dim of channels of attention module.
pred_dropout (float): Dropout probability of prediction layer.
max_seq_len (int): Maximum sequence length for decoding.
mask (bool): If True, mask padding in feature map.
start_idx (int): Index of start token.
padding_idx (int): Index of padding token.
pred_concat (bool): If True, concat glimpse feature from
attention with holistic feature and hidden state.
"""
def __init__(
self,
out_channels, # 90 + unknown + start + padding
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_drop_rnn=0.0,
dec_gru=False,
d_model=512,
d_enc=512,
d_k=64,
pred_dropout=0.0,
max_text_length=30,
mask=True,
pred_concat=True,
**kwargs):
super().__init__()
self.num_classes = out_channels
self.enc_bi_rnn = enc_bi_rnn
self.d_k = d_k
self.start_idx = out_channels - 2
self.padding_idx = out_channels - 1
self.max_seq_len = max_text_length
self.mask = mask
self.pred_concat = pred_concat
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
# 2D attention layer
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
self.conv3x3_1 = nn.Conv2d(
d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Linear(d_k, 1)
# Decoder RNN layer
# if dec_bi_rnn:
# direction = 'bidirectional'
# else:
# direction = 'forward'
kwargs = dict(
input_size=encoder_rnn_out_size,
hidden_size=encoder_rnn_out_size,
num_layers=2,
batch_first=True,
dropout=dec_drop_rnn,
bidirectional=dec_bi_rnn)
if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs)
else:
self.rnn_decoder = nn.LSTM(**kwargs)
# Decoder input embedding
self.embedding = nn.Embedding(
self.num_classes,
encoder_rnn_out_size,
padding_idx=self.padding_idx)
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = self.num_classes - 1
if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
else:
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
def _2d_attention(self,
decoder_input,
feat,
holistic_feat,
valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
bsz, seq_len, attn_size = attn_query.shape
# attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)
# attn_query = attn_query.unsqueeze(3).unsqueeze(4)
# (bsz, seq_len + 1, attn_size, 1, 1)
attn_key = self.conv3x3_1(feat)
# bsz * attn_size * h * w
attn_key = attn_key.unsqueeze(1)
# bsz * 1 * attn_size * h * w
attn_weight = torch.tanh(torch.add(attn_key, attn_query))
# bsz * (seq_len + 1) * attn_size * h * w
attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
# bsz * (seq_len + 1) * h * w * attn_size
attn_weight = self.conv1x1_2(attn_weight)
# bsz * (seq_len + 1) * h * w * 1
bsz, T, h, w, c = attn_weight.size()
assert c == 1
if valid_ratios is not None:
# cal mask of attention weight
for i in range(valid_ratios.size(0)):
valid_width = torch.min(w, torch.ceil(w * valid_ratios[i]))
# valid_width = paddle.minimum(
# w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
if valid_width < w:
attn_weight[i, :, :, valid_width:, :] = float('-inf')
# attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = attn_weight.view(bsz, T, -1)
attn_weight = F.softmax(attn_weight, dim=-1)
attn_weight = attn_weight.view(bsz, T, h, w,
c).permute(0, 1, 4, 2, 3).contiguous()
# attn_weight: bsz * T * c * h * w
# feat: bsz * c * h * w
attn_feat = torch.sum(
torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
# bsz * (seq_len + 1) * C
# Linear transformation
if self.pred_concat:
hf_c = holistic_feat.shape[-1]
holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
if self.train_mode:
y = self.pred_dropout(y)
return y
def forward_train(self, feat, out_enc, label, img_metas):
'''
img_metas: [label, valid_ratio]
'''
if img_metas is not None:
assert img_metas[0].size(0) == feat.size(0)
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
in_dec = torch.cat((out_enc, lab_embedding), dim=1)
# bsz * (seq_len + 1) * C
out_dec = self._2d_attention(
in_dec, feat, out_enc, valid_ratios=valid_ratios)
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas):
if img_metas is not None:
assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
seq_len = self.max_seq_len
bsz = feat.size(0)
start_token = torch.full(
(bsz, ), fill_value=self.start_idx, device=feat.device,dtype=torch.long)
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
emb_dim = start_token.shape[1]
# start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
start_token = start_token.unsqueeze(1).expand(bsz, seq_len, emb_dim)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
decoder_input = torch.cat((out_enc, start_token), dim=1)
# bsz * (seq_len + 1) * emb_dim
outputs = []
for i in range(1, seq_len + 1):
decoder_output = self._2d_attention(
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1)
outputs.append(char_output)
_, max_idx = torch.max(char_output, dim=1, keepdim=False)
char_embedding = self.embedding(max_idx) # bsz * emb_dim
if i < seq_len:
decoder_input[:, i + 1, :] = char_embedding
outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes
return outputs
class SARHead(nn.Module):
def __init__(self,
in_channels,
out_channels,
enc_dim=512,
max_text_length=30,
enc_bi_rnn=False,
enc_drop_rnn=0.1,
enc_gru=False,
dec_bi_rnn=False,
dec_drop_rnn=0.0,
dec_gru=False,
d_k=512,
pred_dropout=0.1,
pred_concat=True,
**kwargs):
super(SARHead, self).__init__()
# encoder module
self.encoder = SAREncoder(
enc_bi_rnn=enc_bi_rnn,
enc_drop_rnn=enc_drop_rnn,
enc_gru=enc_gru,
d_model=in_channels,
d_enc=enc_dim)
# decoder module
self.decoder = ParallelSARDecoder(
out_channels=out_channels,
enc_bi_rnn=enc_bi_rnn,
dec_bi_rnn=dec_bi_rnn,
dec_drop_rnn=dec_drop_rnn,
dec_gru=dec_gru,
d_model=in_channels,
d_enc=enc_dim,
d_k=d_k,
pred_dropout=pred_dropout,
max_text_length=max_text_length,
pred_concat=pred_concat)
def forward(self, feat, targets=None):
'''
img_metas: [label, valid_ratio]
'''
holistic_feat = self.encoder(feat, targets) # bsz c
if self.training:
label = targets[0] # label
final_out = self.decoder(
feat, holistic_feat, label, img_metas=targets)
else:
final_out = self.decoder(
feat,
holistic_feat,
label=None,
img_metas=targets,
train_mode=False)
# (bsz, seq_len, num_classes)
return final_out
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from collections import OrderedDict
gradient_clip = 10
# https://forums.fast.ai/t/lambda-layer/28507/5
class Lambda(nn.Module):
"An easy way to create a pytorch layer for a simple `func`."
def __init__(self, func):
"create a layer that simply calls `func` with `x`"
super().__init__()
self.func=func
def forward(self, x):
return self.func(x)
class PVAM(nn.Module):
def __init__(self, in_channels, char_num, max_text_length, num_heads,
num_encoder_tus, hidden_dims):
super(PVAM, self).__init__()
self.char_num = char_num
self.max_length = max_text_length
self.num_heads = num_heads
self.num_encoder_TUs = num_encoder_tus
self.hidden_dims = hidden_dims
# Transformer encoder
t = 256
c = 512
self.wrap_encoder_for_feature = WrapEncoderForFeature(
src_vocab_size=1,
max_length=t,
n_layer=self.num_encoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.0,#0.1,
attention_dropout=0.0,#0.1,
relu_dropout=0.0,#0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True)
# PVAM
self.flatten0 = Lambda(lambda x: torch.flatten(x, start_dim=0, end_dim=1))
self.fc0 = torch.nn.Linear(
in_features=in_channels,
out_features=in_channels, )
self.emb = torch.nn.Embedding(
num_embeddings=self.max_length, embedding_dim=in_channels)
self.flatten1 = Lambda(lambda x: torch.flatten(x, start_dim=0, end_dim=2))
self.fc1 = torch.nn.Linear(
in_features=in_channels, out_features=1, bias=False)
def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
b, c, h, w = inputs.shape
conv_features = torch.reshape(inputs, shape=[-1, c, h * w])
conv_features = conv_features.permute(0, 2, 1)
# transformer encoder
b, t, c = conv_features.shape
enc_inputs = [conv_features, encoder_word_pos, None]
word_features = self.wrap_encoder_for_feature(enc_inputs)
# pvam
b, t, c = word_features.shape
word_features = self.fc0(word_features)
word_features_ = torch.reshape(word_features, [-1, 1, t, c])
word_features_ = word_features_.repeat([1, self.max_length, 1, 1])
word_pos_feature = self.emb(gsrm_word_pos)
word_pos_feature_ = torch.reshape(word_pos_feature,
[-1, self.max_length, 1, c])
word_pos_feature_ = word_pos_feature_.repeat([1, 1, t, 1])
y = word_pos_feature_ + word_features_
y = torch.tanh(y)
attention_weight = self.fc1(y)
attention_weight = torch.reshape(
attention_weight, shape=[-1, self.max_length, t])
attention_weight = F.softmax(attention_weight, dim=-1)
pvam_features = torch.matmul(attention_weight,
word_features) #[b, max_length, c]
return pvam_features
class GSRM(nn.Module):
def __init__(self, in_channels, char_num, max_text_length, num_heads,
num_encoder_tus, num_decoder_tus, hidden_dims):
super(GSRM, self).__init__()
self.char_num = char_num
self.max_length = max_text_length
self.num_heads = num_heads
self.num_encoder_TUs = num_encoder_tus
self.num_decoder_TUs = num_decoder_tus
self.hidden_dims = hidden_dims
self.fc0 = torch.nn.Linear(
in_features=in_channels, out_features=self.char_num)
self.wrap_encoder0 = WrapEncoder(
src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True)
self.wrap_encoder1 = WrapEncoder(
src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True)
self.mul = lambda x: torch.matmul(x,
self.wrap_encoder0.prepare_decoder.emb0.weight.t(),
)
def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2):
# ===== GSRM Visual-to-semantic embedding block =====
b, t, c = inputs.shape
pvam_features = torch.reshape(inputs, [-1, c])
word_out = self.fc0(pvam_features)
word_ids = torch.argmax(F.softmax(word_out, dim=-1), dim=1)
word_ids = torch.reshape(word_ids, shape=[-1, t, 1])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx = self.char_num
word1 = F.pad(word_ids.type(torch.float32), [0, 0, 1, 0, 0, 0], value=1.0 * pad_idx)
word1 = word1.type(torch.int64)
word1 = word1[:, :-1, :]
word2 = word_ids
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
gsrm_feature2 = F.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0],
value=0.,
)
gsrm_feature2 = gsrm_feature2[:, 1:, ]
gsrm_features = gsrm_feature1 + gsrm_feature2
gsrm_out = self.mul(gsrm_features)
b, t, c = gsrm_out.shape
gsrm_out = torch.reshape(gsrm_out, [-1, c])
return gsrm_features, word_out, gsrm_out
class VSFD(nn.Module):
def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
super(VSFD, self).__init__()
self.char_num = char_num
self.fc0 = torch.nn.Linear(
in_features=in_channels * 2, out_features=pvam_ch)
self.fc1 = torch.nn.Linear(
in_features=pvam_ch, out_features=self.char_num)
def forward(self, pvam_feature, gsrm_feature):
b, t, c1 = pvam_feature.shape
b, t, c2 = gsrm_feature.shape
combine_feature_ = torch.cat([pvam_feature, gsrm_feature], dim=2)
img_comb_feature_ = torch.reshape(
combine_feature_, shape=[-1, c1 + c2])
img_comb_feature_map = self.fc0(img_comb_feature_)
img_comb_feature_map = torch.sigmoid(img_comb_feature_map)
img_comb_feature_map = torch.reshape(
img_comb_feature_map, shape=[-1, t, c1])
combine_feature = img_comb_feature_map * pvam_feature + (
1.0 - img_comb_feature_map) * gsrm_feature
img_comb_feature = torch.reshape(combine_feature, shape=[-1, c1])
out = self.fc1(img_comb_feature)
return out
class SRNHead(nn.Module):
def __init__(self, in_channels, out_channels, max_text_length, num_heads,
num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
super(SRNHead, self).__init__()
self.char_num = out_channels
self.max_length = max_text_length
self.num_heads = num_heads
self.num_encoder_TUs = num_encoder_TUs
self.num_decoder_TUs = num_decoder_TUs
self.hidden_dims = hidden_dims
self.pvam = PVAM(
in_channels=in_channels,
char_num=self.char_num,
max_text_length=self.max_length,
num_heads=self.num_heads,
num_encoder_tus=self.num_encoder_TUs,
hidden_dims=self.hidden_dims)
self.gsrm = GSRM(
in_channels=in_channels,
char_num=self.char_num,
max_text_length=self.max_length,
num_heads=self.num_heads,
num_encoder_tus=self.num_encoder_TUs,
num_decoder_tus=self.num_decoder_TUs,
hidden_dims=self.hidden_dims)
self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
def forward(self, inputs, others):
encoder_word_pos = others[0]
gsrm_word_pos = others[1].type(torch.long)
gsrm_slf_attn_bias1 = others[2]
gsrm_slf_attn_bias2 = others[3]
pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
gsrm_feature, word_out, gsrm_out = self.gsrm(
pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
final_out = self.vsfd(pvam_feature, gsrm_feature)
if not self.training:
final_out = F.softmax(final_out, dim=1)
_, decoded_out = torch.topk(final_out, k=1)
predicts = OrderedDict([
('predict', final_out),
('pvam_feature', pvam_feature),
('decoded_out', decoded_out),
('word_out', word_out),
('gsrm_out', gsrm_out),
])
return predicts
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
import numpy as np
gradient_clip = 10
class WrapEncoderForFeature(nn.Module):
def __init__(self,
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_idx=0):
super(WrapEncoderForFeature, self).__init__()
self.prepare_encoder = PrepareEncoder(
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name="src_word_emb_table")
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
def forward(self, enc_inputs):
conv_features, src_pos, src_slf_attn_bias = enc_inputs
enc_input = self.prepare_encoder(conv_features, src_pos)
enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output
class WrapEncoder(nn.Module):
"""
embedder + encoder
"""
def __init__(self,
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_idx=0):
super(WrapEncoder, self).__init__()
self.prepare_decoder = PrepareDecoder(
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx)
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
def forward(self, enc_inputs):
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = self.prepare_decoder(src_word, src_pos)
enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output
class Encoder(nn.Module):
"""
encoder
"""
def __init__(self,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(Encoder, self).__init__()
self.encoder_layers = nn.ModuleList()
for i in range(n_layer):
encoderLayer = EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)
self.encoder_layers.add_module("layer_%d" % i, encoderLayer)
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
for encoder_layer in self.encoder_layers:
enc_output = encoder_layer(enc_input, attn_bias)
enc_input = enc_output
enc_output = self.processer(enc_output)
return enc_output
class EncoderLayer(nn.Module):
"""
EncoderLayer
"""
def __init__(self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(
self.preprocesser1(enc_input), None, None, attn_bias)
attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output))
ffn_output = self.postprocesser2(ffn_output, attn_output)
return ffn_output
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention
"""
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = torch.nn.Linear(
in_features=d_model, out_features=d_key * n_head, bias=False)
self.k_fc = torch.nn.Linear(
in_features=d_model, out_features=d_key * n_head, bias=False)
self.v_fc = torch.nn.Linear(
in_features=d_model, out_features=d_value * n_head, bias=False)
self.proj_fc = torch.nn.Linear(
in_features=d_value * n_head, out_features=d_model, bias=False)
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
keys, values = queries, queries
static_kv = False
else: # cross-attention
static_kv = True
q = self.q_fc(queries)
q = torch.reshape(q, shape=[q.size(0), q.size(1), self.n_head, self.d_key])
q = q.permute(0, 2, 1, 3)
if cache is not None and static_kv and "static_k" in cache:
# for encoder-decoder attention in inference and has cached
k = cache["static_k"]
v = cache["static_v"]
else:
k = self.k_fc(keys)
v = self.v_fc(values)
k = torch.reshape(k, shape=[k.size(0), k.size(1), self.n_head, self.d_key])
k = k.permute(0, 2, 1, 3)
v = torch.reshape(v, shape=[v.size(0), v.size(1), self.n_head, self.d_value])
v = v.permute(0, 2, 1, 3)
if cache is not None:
if static_kv and not "static_k" in cache:
# for encoder-decoder attention in inference and has not cached
cache["static_k"], cache["static_v"] = k, v
elif not static_kv:
# for decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
k = torch.cat([cache_k, k], dim=2)
v = torch.cat([cache_v, v], dim=2)
cache["k"], cache["v"] = k, v
return q, k, v
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
keys = queries if keys is None else keys
values = keys if values is None else values
q, k, v = self._prepare_qkv(queries, keys, values, cache)
# scale dot product attention
product = torch.matmul(q, k.transpose(2, 3))
product = product * self.d_model**-0.5
if attn_bias is not None:
product += attn_bias
weights = F.softmax(product, dim=-1)
if self.dropout_rate:
weights = F.dropout(
weights, p=self.dropout_rate)
out = torch.matmul(weights, v)
# combine heads
out = out.permute(0, 2, 1, 3)
out = torch.reshape(out, shape=[out.size(0), out.size(1), out.shape[2] * out.shape[3]])
# project to output
out = self.proj_fc(out)
return out
# https://forums.fast.ai/t/lambda-layer/28507/5
class Lambda(nn.Module):
"An easy way to create a pytorch layer for a simple `func`."
def __init__(self, func):
"create a layer that simply calls `func` with `x`"
super().__init__()
self.func=func
def forward(self, x):
return self.func(x)
class LambdaXY(nn.Module):
"An easy way to create a pytorch layer for a simple `func`."
def __init__(self, func):
"create a layer that simply calls `func` with `x`"
super().__init__()
self.func=func
def forward(self, x, y):
return self.func(x, y)
class PrePostProcessLayer(nn.Module):
"""
PrePostProcessLayer
"""
def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
self.functors = nn.ModuleList()
cur_a_len = 0
cur_n_len = 0
cur_d_len = 0
for cmd in self.process_cmd:
if cmd == "a": # add residual connection
self.functors.add_module('add_res_connect_{}'.format(cur_a_len), LambdaXY(lambda x, y: x + y if y is not None else x))
cur_a_len += 1
elif cmd == "n": # add layer normalization
layerNorm = torch.nn.LayerNorm(normalized_shape=d_model,
elementwise_affine=True,
eps=1e-5)
self.functors.add_module("layer_norm_%d" % cur_n_len,
layerNorm)
cur_n_len += 1
elif cmd == "d": # add dropout
self.functors.add_module('add_drop_{}'.format(cur_d_len),
Lambda(lambda x: F.dropout(
x, p=dropout_rate)
if dropout_rate else x)
)
cur_d_len += 1
def forward(self, x, residual=None):
for i, (cmd, functor) in enumerate(zip(self.process_cmd, self.functors)):
if cmd == "a":
x = functor(x, residual)
else:
x = functor(x)
return x
class PrepareEncoder(nn.Module):
def __init__(self,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0,
bos_idx=0,
word_emb_param_name=None,
pos_enc_param_name=None):
super(PrepareEncoder, self).__init__()
self.src_emb_dim = src_emb_dim
self.src_max_len = src_max_len
self.emb = torch.nn.Embedding(
num_embeddings=self.src_max_len,
embedding_dim=self.src_emb_dim,
sparse=True,
)
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
src_word_emb = src_word.type(torch.float32)
src_word_emb = self.src_emb_dim**0.5 * src_word_emb
src_pos = torch.squeeze(src_pos, dim=-1)
src_pos_enc = self.emb(src_pos.type(torch.int64))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
if self.dropout_rate:
out = F.dropout(
enc_input, p=self.dropout_rate)
else:
out = enc_input
return out
class PrepareDecoder(nn.Module):
def __init__(self,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0,
bos_idx=0,
word_emb_param_name=None,
pos_enc_param_name=None):
super(PrepareDecoder, self).__init__()
self.src_emb_dim = src_emb_dim
"""
self.emb0 = Embedding(num_embeddings=src_vocab_size,
embedding_dim=src_emb_dim)
"""
self.emb0 = torch.nn.Embedding(
num_embeddings=src_vocab_size,
embedding_dim=self.src_emb_dim,
padding_idx=bos_idx,
)
self.emb1 = torch.nn.Embedding(
num_embeddings=src_max_len,
embedding_dim=self.src_emb_dim,
)
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
src_word = torch.squeeze(src_word.type(torch.int64), dim=-1)
src_word_emb = self.emb0(src_word)
src_word_emb = self.src_emb_dim**0.5 * src_word_emb
src_pos = torch.squeeze(src_pos, dim=-1)
src_pos_enc = self.emb1(src_pos)
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
if self.dropout_rate:
out = F.dropout(
enc_input, p=self.dropout_rate)
else:
out = enc_input
return out
class FFN(nn.Module):
"""
Feed-Forward Network
"""
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
self.fc1 = torch.nn.Linear(
in_features=d_model, out_features=d_inner_hid)
self.fc2 = torch.nn.Linear(
in_features=d_inner_hid, out_features=d_model)
def forward(self, x):
hidden = self.fc1(x)
hidden = F.relu(hidden)
if self.dropout_rate:
hidden = F.dropout(
hidden, p=self.dropout_rate)
out = self.fc2(hidden)
return out
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def subsequent_mask(size):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = torch.ones(1, size, size, dtype=torch.float32)
mask_inf = torch.triu(
torch.full(
size=[1, size, size], fill_value=-np.inf, dtype=torch.float32),
diagonal=1)
mask = mask + mask_inf
padding_mask = torch.equal(mask, torch.Tensor(1).type(mask.dtype))
return padding_mask
def clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def attention(query, key, value, mask=None, dropout=None, attention_map=None):
d_k = query.shape[-1]
scores = torch.matmul(query,
key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
else:
pass
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
self.compress_attention = compress_attention
self.compress_attention_linear = nn.Linear(h, 1)
def forward(self, query, key, value, mask=None, attention_map=None):
if mask is not None:
mask = mask.unsqueeze(1)
nbatches = query.size(0)
query, key, value = \
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))]
x, attention_map = attention(
query,
key,
value,
mask=mask,
dropout=self.dropout,
attention_map=attention_map)
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x), attention_map
class ResNet(nn.Module):
def __init__(self, num_in, block, layers):
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(num_in, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.relu1 = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d((2, 2), (2, 2))
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(128)
self.relu2 = nn.ReLU(inplace=True)
self.layer1_pool = nn.MaxPool2d((2, 2), (2, 2))
self.layer1 = self._make_layer(block, 128, 256, layers[0])
self.layer1_conv = nn.Conv2d(256, 256, 3, 1, 1)
self.layer1_bn = nn.BatchNorm2d(256)
self.layer1_relu = nn.ReLU(inplace=True)
self.layer2_pool = nn.MaxPool2d((2, 2), (2, 2))
self.layer2 = self._make_layer(block, 256, 256, layers[1])
self.layer2_conv = nn.Conv2d(256, 256, 3, 1, 1)
self.layer2_bn = nn.BatchNorm2d(256)
self.layer2_relu = nn.ReLU(inplace=True)
self.layer3_pool = nn.MaxPool2d((2, 2), (2, 2))
self.layer3 = self._make_layer(block, 256, 512, layers[2])
self.layer3_conv = nn.Conv2d(512, 512, 3, 1, 1)
self.layer3_bn = nn.BatchNorm2d(512)
self.layer3_relu = nn.ReLU(inplace=True)
self.layer4_pool = nn.MaxPool2d((2, 2), (2, 2))
self.layer4 = self._make_layer(block, 512, 512, layers[3])
self.layer4_conv2 = nn.Conv2d(512, 1024, 3, 1, 1)
self.layer4_conv2_bn = nn.BatchNorm2d(1024)
self.layer4_conv2_relu = nn.ReLU(inplace=True)
def _make_layer(self, block, inplanes, planes, blocks):
if inplanes != planes:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes, 3, 1, 1),
nn.BatchNorm2d(
planes), )
else:
downsample = None
layers = []
layers.append(block(inplanes, planes, downsample))
for i in range(1, blocks):
layers.append(block(planes, planes, downsample=None))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.layer1_pool(x)
x = self.layer1(x)
x = self.layer1_conv(x)
x = self.layer1_bn(x)
x = self.layer1_relu(x)
x = self.layer2(x)
x = self.layer2_conv(x)
x = self.layer2_bn(x)
x = self.layer2_relu(x)
x = self.layer3(x)
x = self.layer3_conv(x)
x = self.layer3_bn(x)
x = self.layer3_relu(x)
x = self.layer4(x)
x = self.layer4_conv2(x)
x = self.layer4_conv2_bn(x)
x = self.layer4_conv2_relu(x)
return x
class Bottleneck(nn.Module):
def __init__(self, input_dim):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(input_dim, input_dim, 1)
self.bn1 = nn.BatchNorm2d(input_dim)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(input_dim, input_dim, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(input_dim)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, dropout, dim, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, dim, 2).float() *
(-math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class PositionwiseFeedForward(nn.Module):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Generator(nn.Module):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.proj = nn.Linear(d_model, vocab)
self.relu = nn.ReLU()
def forward(self, x):
out = self.proj(x)
return out
class Embeddings(nn.Module):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
embed = self.lut(x) * math.sqrt(self.d_model)
return embed
class LayerNorm(nn.Module):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = nn.parameter.Parameter(torch.ones(features))
self.b_2 = nn.parameter.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.mask_multihead = MultiHeadedAttention(
h=16, d_model=1024, dropout=0.1)
self.mul_layernorm1 = LayerNorm(1024)
self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
self.mul_layernorm2 = LayerNorm(1024)
self.pff = PositionwiseFeedForward(1024, 2048)
self.mul_layernorm3 = LayerNorm(1024)
def forward(self, text, conv_feature, attention_map=None):
text_max_length = text.shape[1]
mask = subsequent_mask(text_max_length)
result = text
result = self.mul_layernorm1(result + self.mask_multihead(
text, text, text, mask=mask)[0])
b, c, h, w = conv_feature.shape
conv_feature = conv_feature.view(b, c, h * w).permute(0, 2, 1).contiguous()
word_image_align, attention_map = self.multihead(
result,
conv_feature,
conv_feature,
mask=None,
attention_map=attention_map)
result = self.mul_layernorm2(result + word_image_align)
result = self.mul_layernorm3(result + self.pff(result))
return result, attention_map
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, downsample):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample != None:
residual = self.downsample(residual)
out += residual
out = self.relu(out)
return out
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])
def forward(self, input):
conv_result = self.cnn(input)
return conv_result
class Transformer(nn.Module):
def __init__(self, in_channels=1, alphabet='0123456789'):
super(Transformer, self).__init__()
self.alphabet = alphabet
word_n_class = self.get_alphabet_len()
self.embedding_word_with_upperword = Embeddings(512, word_n_class)
self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
self.encoder = Encoder()
self.decoder = Decoder()
self.generator_word_with_upperword = Generator(1024, word_n_class)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def get_alphabet_len(self):
return len(self.alphabet)
def forward(self, image, text_length, text_input, attention_map=None):
if image.shape[1] == 3:
R = image[:, 0:1, :, :]
G = image[:, 1:2, :, :]
B = image[:, 2:3, :, :]
image = 0.299 * R + 0.587 * G + 0.114 * B
conv_feature = self.encoder(image) # batch, 1024, 8, 32
max_length = max(text_length)
text_input = text_input[:, :max_length]
text_embedding = self.embedding_word_with_upperword(
text_input) # batch, text_max_length, 512
if torch.cuda.is_available():
postion_embedding = self.pe(
torch.zeros(text_embedding.shape).cuda()).cuda()
else:
postion_embedding = self.pe(
torch.zeros(text_embedding.shape)) # batch, text_max_length, 512
text_input_with_pe = torch.cat([text_embedding, postion_embedding], 2) # batch, text_max_length, 1024
batch, seq_len, _ = text_input_with_pe.shape
text_input_with_pe, word_attention_map = self.decoder(
text_input_with_pe, conv_feature)
word_decoder_result = self.generator_word_with_upperword(
text_input_with_pe)
if self.training:
total_length = torch.sum(text_length).data
probs_res = torch.zeros([total_length, self.get_alphabet_len()]).type_as(word_decoder_result.data)
start = 0
for index, length in enumerate(text_length):
length = int(length.numpy())
probs_res[start:start + length, :] = word_decoder_result[
index, 0:0 + length, :]
start = start + length
return probs_res, word_attention_map, None
else:
return word_decoder_result
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Module):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = kwargs.get('max_elem_length', 500)
self.max_cell_num = 500
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len
if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char.type(torch.int64), onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
# fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
fea = fea.permute(0, 2, 1)
batch_size = fea.shape[0]
hidden = torch.zeros((batch_size, self.hidden_size))
output_hiddens = []
if self.training and targets is not None:
raise NotImplementedError
else:
temp_elem = torch.zeros([batch_size], dtype=torch.int32)
structure_probs = None
loc_preds = None
elem_onehots = None
outputs = None
alpha = None
max_elem_length = torch.as_tensor(self.max_elem_length)
i = 0
while i < max_elem_length + 1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(torch.unsqueeze(outputs, dim=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(dim=1, keepdim=False)
i += 1
output = torch.cat(output_hiddens, dim=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs, dim=-1)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.permute(0, 2, 1)
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.permute(0, 2, 1)
loc_concat = torch.cat([output, loc_fea], dim=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)
res = torch.add(batch_H_proj, prev_hidden_proj)
res = torch.tanh(res)
e = self.score(res)
alpha = F.softmax(e, dim=1)
alpha = alpha.permute(0, 2, 1)
context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
concat_context = torch.cat([context, char_onehots.float()], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return (cur_hidden, cur_hidden), alpha
class AttentionLSTM(nn.Module):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (torch.zeros((batch_size, self.hidden_size)), torch.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(torch.unsqueeze(hidden[0], dim=1))
output = torch.cat(output_hiddens, dim=1)
probs = self.generator(output)
else:
targets = torch.zeros([batch_size], dtype=torch.int32)
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = torch.unsqueeze(probs_step, dim=1)
else:
probs = torch.cat(
[probs, torch.unsqueeze(
probs_step, dim=1)], dim=1)
next_input = probs_step.argmax(dim=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden[0]), dim=1)
res = torch.add(batch_H_proj, prev_hidden_proj)
res = torch.tanh(res)
e = self.score(res)
alpha = F.softmax(e, dim=1)
alpha = alpha.permute(0, 2, 1)
context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
concat_context = torch.cat([context, char_onehots.float()], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return (cur_hidden, cur_hidden), alpha
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['build_neck']
def build_neck(config):
from .db_fpn import DBFPN, RSEFPN, LKPAN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
from .fpn import FPN
from .fce_fpn import FCEFPN
from .table_fpn import TableFPN
support_dict = ['FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN',
'RSEFPN', 'LKPAN', 'FCEFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
support_dict))
module_class = eval(module_name)(**config)
return module_class
\ No newline at end of file
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.backbones.det_mobilenet_v3 import SEModule
from pytorchocr.modeling.necks.intracl import IntraCLBlock
def hard_swish(x, inplace=True):
return x * F.relu6(x + 3., inplace=inplace) / 6.
class DSConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding,
stride=1,
groups=None,
if_act=True,
act="relu",
**kwargs):
super(DSConv, self).__init__()
if groups == None:
groups = in_channels
self.if_act = if_act
self.act = act
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(
in_channels=in_channels,
out_channels=int(in_channels * 4),
kernel_size=1,
stride=1,
bias=False)
self.bn2 = nn.BatchNorm2d(int(in_channels * 4))
self.conv3 = nn.Conv2d(
in_channels=int(in_channels * 4),
out_channels=out_channels,
kernel_size=1,
stride=1,
bias=False)
self._c = [in_channels, out_channels]
if in_channels != out_channels:
self.conv_end = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
bias=False)
def forward(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = hard_swish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
x = self.conv3(x)
if self._c[0] != self._c[1]:
x = x + self.conv_end(inputs)
return x
class DBFPN(nn.Module):
def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__()
self.out_channels = out_channels
self.use_asf = use_asf
self.in2_conv = nn.Conv2d(
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
bias=False)
self.in3_conv = nn.Conv2d(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
bias=False)
self.in4_conv = nn.Conv2d(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
bias=False)
self.in5_conv = nn.Conv2d(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
bias=False)
self.p5_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False)
self.p4_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False)
self.p3_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False)
self.p2_conv = nn.Conv2d(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
bias=False)
if self.use_asf is True:
self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.in5_conv(c5)
in4 = self.in4_conv(c4)
in3 = self.in3_conv(c3)
in2 = self.in2_conv(c2)
out4 = in4 + F.interpolate(
in5, scale_factor=2, mode="nearest", )#align_mode=1) # 1/16
out3 = in3 + F.interpolate(
out4, scale_factor=2, mode="nearest", )#align_mode=1) # 1/8
out2 = in2 + F.interpolate(
out3, scale_factor=2, mode="nearest", )#align_mode=1) # 1/4
p5 = self.p5_conv(in5)
p4 = self.p4_conv(out4)
p3 = self.p3_conv(out3)
p2 = self.p2_conv(out2)
p5 = F.interpolate(p5, scale_factor=8, mode="nearest", )#align_mode=1)
p4 = F.interpolate(p4, scale_factor=4, mode="nearest", )#align_mode=1)
p3 = F.interpolate(p3, scale_factor=2, mode="nearest", )#align_mode=1)
fuse = torch.cat([p5, p4, p3, p2], dim=1)
if self.use_asf is True:
fuse = self.asf(fuse, [p5, p4, p3, p2])
return fuse
class RSELayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
super(RSELayer, self).__init__()
self.out_channels = out_channels
self.in_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
padding=int(kernel_size // 2),
bias=False)
self.se_block = SEModule(self.out_channels)
self.shortcut = shortcut
def forward(self, ins):
x = self.in_conv(ins)
if self.shortcut:
out = x + self.se_block(x)
else:
out = self.se_block(x)
return out
class RSEFPN(nn.Module):
def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
super(RSEFPN, self).__init__()
self.out_channels = out_channels
self.ins_conv = nn.ModuleList()
self.inp_conv = nn.ModuleList()
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
for i in range(len(in_channels)):
self.ins_conv.append(
RSELayer(
in_channels[i],
out_channels,
kernel_size=1,
shortcut=shortcut))
self.inp_conv.append(
RSELayer(
out_channels,
out_channels // 4,
kernel_size=3,
shortcut=shortcut))
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.interpolate(
in5, scale_factor=2, mode="nearest") # 1/16
out3 = in3 + F.interpolate(
out4, scale_factor=2, mode="nearest") # 1/8
out2 = in2 + F.interpolate(
out3, scale_factor=2, mode="nearest") # 1/4
p5 = self.inp_conv[3](in5)
p4 = self.inp_conv[2](out4)
p3 = self.inp_conv[1](out3)
p2 = self.inp_conv[0](out2)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
fuse = torch.cat([p5, p4, p3, p2], dim=1)
return fuse
class LKPAN(nn.Module):
def __init__(self, in_channels, out_channels, mode='large', **kwargs):
super(LKPAN, self).__init__()
self.out_channels = out_channels
self.ins_conv = nn.ModuleList()
self.inp_conv = nn.ModuleList()
# pan head
self.pan_head_conv = nn.ModuleList()
self.pan_lat_conv = nn.ModuleList()
if mode.lower() == 'lite':
p_layer = DSConv
elif mode.lower() == 'large':
p_layer = nn.Conv2d
else:
raise ValueError(
"mode can only be one of ['lite', 'large'], but received {}".
format(mode))
for i in range(len(in_channels)):
self.ins_conv.append(
nn.Conv2d(
in_channels=in_channels[i],
out_channels=self.out_channels,
kernel_size=1,
bias=False))
self.inp_conv.append(
p_layer(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
bias=False))
if i > 0:
self.pan_head_conv.append(
nn.Conv2d(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
stride=2,
bias=False))
self.pan_lat_conv.append(
p_layer(
in_channels=self.out_channels // 4,
out_channels=self.out_channels // 4,
kernel_size=9,
padding=4,
bias=False))
self.intracl = False
if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
self.intracl = kwargs['intracl']
self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.ins_conv[3](c5)
in4 = self.ins_conv[2](c4)
in3 = self.ins_conv[1](c3)
in2 = self.ins_conv[0](c2)
out4 = in4 + F.interpolate(
in5, scale_factor=2, mode="nearest") # 1/16
out3 = in3 + F.interpolate(
out4, scale_factor=2, mode="nearest") # 1/8
out2 = in2 + F.interpolate(
out3, scale_factor=2, mode="nearest") # 1/4
f5 = self.inp_conv[3](in5)
f4 = self.inp_conv[2](out4)
f3 = self.inp_conv[1](out3)
f2 = self.inp_conv[0](out2)
pan3 = f3 + self.pan_head_conv[0](f2)
pan4 = f4 + self.pan_head_conv[1](pan3)
pan5 = f5 + self.pan_head_conv[2](pan4)
p2 = self.pan_lat_conv[0](f2)
p3 = self.pan_lat_conv[1](pan3)
p4 = self.pan_lat_conv[2](pan4)
p5 = self.pan_lat_conv[3](pan5)
if self.intracl is True:
p5 = self.incl4(p5)
p4 = self.incl3(p4)
p3 = self.incl2(p3)
p2 = self.incl1(p2)
p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
fuse = torch.cat([p5, p4, p3, p2], dim=1)
return fuse
class ASFBlock(nn.Module):
"""
This code is refered from:
https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
"""
def __init__(self, in_channels, inter_channels, out_features_num=4):
"""
Adaptive Scale Fusion (ASF) block of DBNet++
Args:
in_channels: the number of channels in the input data
inter_channels: the number of middle channels
out_features_num: the number of fused stages
"""
super(ASFBlock, self).__init__()
self.in_channels = in_channels
self.inter_channels = inter_channels
self.out_features_num = out_features_num
self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
self.spatial_scale = nn.Sequential(
#Nx1xHxW
nn.Conv2d(
in_channels=1,
out_channels=1,
kernel_size=3,
bias=False,
padding=1,
),
nn.ReLU(),
nn.Conv2d(
in_channels=1,
out_channels=1,
kernel_size=1,
bias=False,
),
nn.Sigmoid())
self.channel_scale = nn.Sequential(
nn.Conv2d(
in_channels=inter_channels,
out_channels=out_features_num,
kernel_size=1,
bias=False,
),
nn.Sigmoid())
def forward(self, fuse_features, features_list):
fuse_features = self.conv(fuse_features)
spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
attention_scores = self.spatial_scale(spatial_x) + fuse_features
attention_scores = self.channel_scale(attention_scores)
assert len(features_list) == self.out_features_num
out_list = []
for i in range(self.out_features_num):
out_list.append(attention_scores[:, i:i + 1] * features_list[i])
return torch.cat(out_list, dim=1)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
# import paddle
# from paddle import nn
# import paddle.nn.functional as F
# from paddle import ParamAttr
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(
out_channels,)
self.act = act
if act is not None:
self._act = Activation(act)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.act is not None:
x = self._act(x)
return x
class DeConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(
out_channels,
)
self.act = act
if act is not None:
self._act = Activation(act)
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
if self.act is not None:
x = self._act(x)
return x
class EASTFPN(nn.Module):
def __init__(self, in_channels, model_name, **kwargs):
super(EASTFPN, self).__init__()
self.model_name = model_name
if self.model_name == "large":
self.out_channels = 128
else:
self.out_channels = 64
self.in_channels = in_channels[::-1]
self.h1_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[1],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_1")
self.h2_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[2],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_2")
self.h3_conv = ConvBNLayer(
in_channels=self.out_channels+self.in_channels[3],
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_h_3")
self.g0_deconv = DeConvBNLayer(
in_channels=self.in_channels[0],
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_0")
self.g1_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_1")
self.g2_deconv = DeConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=4,
stride=2,
padding=1,
if_act=True,
act='relu',
name="unet_g_2")
self.g3_conv = ConvBNLayer(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
if_act=True,
act='relu',
name="unet_g_3")
def forward(self, x):
f = x[::-1]
h = f[0]
g = self.g0_deconv(h)
# h = paddle.concat([g, f[1]], axis=1)
h = torch.cat([g, f[1]], dim=1)
h = self.h1_conv(h)
g = self.g1_deconv(h)
# h = paddle.concat([g, f[2]], axis=1)
h = torch.cat([g, f[2]], dim=1)
h = self.h2_conv(h)
g = self.g2_deconv(h)
# h = paddle.concat([g, f[3]], axis=1)
h = torch.cat([g, f[3]], dim=1)
h = self.h3_conv(h)
g = self.g3_conv(h)
return g
\ No newline at end of file
"""
This code is refer from:
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/necks/fpn.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
# import paddle.nn as nn
# import paddle.nn.functional as F
# from paddle import ParamAttr
# from paddle.nn.initializer import XavierUniform
# from paddle.nn.initializer import Normal
# from paddle.regularizer import L2Decay
__all__ = ['FCEFPN']
class ConvNormLayer(nn.Module):
def __init__(self,
ch_in,
ch_out,
filter_size,
stride,
groups=1,
norm_type='bn',
norm_decay=0.,
norm_groups=32,
lr_scale=1.,
freeze_norm=False,
initializer=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
bias_attr = False
self.conv = nn.Conv2d(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias=bias_attr)
norm_lr = 0. if freeze_norm else 1.
# param_attr = ParamAttr(
# learning_rate=norm_lr,
# regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
# bias_attr = ParamAttr(
# learning_rate=norm_lr,
# regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
if norm_type == 'bn':
self.norm = nn.BatchNorm2d(
ch_out,
)
elif norm_type == 'sync_bn':
self.norm = nn.SyncBatchNorm(
ch_out,
)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(
num_groups=norm_groups,
num_channels=ch_out,
affine=bias_attr)
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
return out
class FCEFPN(nn.Module):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
in_channels (list[int]): input channels of each level which can be
derived from the output shape of backbone by from_config
out_channels (list[int]): output channel of each level
spatial_scales (list[float]): the spatial scales between input feature
maps and original input image which can be derived from the output
shape of backbone by from_config
has_extra_convs (bool): whether to add extra conv to the last level.
default False
extra_stage (int): the number of extra stages added to the last level.
default 1
use_c5 (bool): Whether to use c5 as the input of extra stage,
otherwise p5 is used. default True
norm_type (string|None): The normalization type in FPN module. If
norm_type is None, norm will not be used after conv and if
norm_type is string, bn, gn, sync_bn are available. default None
norm_decay (float): weight decay for normalization layer weights.
default 0.
freeze_norm (bool): whether to freeze normalization layer.
default False
relu_before_extra_convs (bool): whether to add relu before extra convs.
default False
"""
def __init__(self,
in_channels,
out_channels,
spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
has_extra_convs=False,
extra_stage=1,
use_c5=True,
norm_type=None,
norm_decay=0.,
freeze_norm=False,
relu_before_extra_convs=True):
super(FCEFPN, self).__init__()
self.out_channels = out_channels
for s in range(extra_stage):
spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
self.spatial_scales = spatial_scales
self.has_extra_convs = has_extra_convs
self.extra_stage = extra_stage
self.use_c5 = use_c5
self.relu_before_extra_convs = relu_before_extra_convs
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.lateral_convs = []#nn.ModuleList()
self.lateral_convs_module = nn.ModuleList()
self.fpn_convs = []#nn.ModuleList()
self.fpn_convs_module = nn.ModuleList()
fan = out_channels * 3 * 3
# stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone
# 0 <= st_stage < ed_stage <= 3
st_stage = 4 - len(in_channels)
ed_stage = st_stage + len(in_channels) - 1
for i in range(st_stage, ed_stage + 1):
if i == 3:
lateral_name = 'fpn_inner_res5_sum'
else:
lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
in_c = in_channels[i - st_stage]
if self.norm_type is not None:
# self.lateral_convs_module.add_module(
# lateral_name,
# ConvNormLayer(
# ch_in=in_c,
# ch_out=out_channels,
# filter_size=1,
# stride=1,
# norm_type=self.norm_type,
# norm_decay=self.norm_decay,
# freeze_norm=self.freeze_norm,
# initializer=None))
lateral = ConvNormLayer(
ch_in=in_c,
ch_out=out_channels,
filter_size=1,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=None)
else:
# self.lateral_convs_module.add_module(
# lateral_name,
# nn.Conv2d(
# in_channels=in_c,
# out_channels=out_channels,
# kernel_size=1,
# )
# )
lateral = nn.Conv2d(
in_channels=in_c,
out_channels=out_channels,
kernel_size=1,
)
self.lateral_convs_module.add_module(lateral_name, lateral)
self.lateral_convs.append(lateral)
for i in range(st_stage, ed_stage + 1):
fpn_name = 'fpn_res{}_sum'.format(i + 2)
fpn_conv_module = nn.Sequential()
if self.norm_type is not None:
# fpn_conv_module.add_module(
# fpn_name,
# ConvNormLayer(
# ch_in=out_channels,
# ch_out=out_channels,
# filter_size=3,
# stride=1,
# norm_type=self.norm_type,
# norm_decay=self.norm_decay,
# freeze_norm=self.freeze_norm,
# initializer=None))
fpn_conv = ConvNormLayer(
ch_in=out_channels,
ch_out=out_channels,
filter_size=3,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=None)
else:
# fpn_conv_module.add_module(
# fpn_name,
# nn.Conv2d(
# in_channels=out_channels,
# out_channels=out_channels,
# kernel_size=3,
# padding=1,
# )
# )
fpn_conv = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
)
self.fpn_convs_module.add_module(fpn_name, fpn_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
if self.has_extra_convs:
for i in range(self.extra_stage):
lvl = ed_stage + 1 + i
if i == 0 and self.use_c5:
in_c = in_channels[-1]
else:
in_c = out_channels
extra_fpn_name = 'fpn_{}'.format(lvl + 2)
extra_fpn_conv_module = nn.Sequential()
if self.norm_type is not None:
# extra_fpn_conv_module.add_module(
# extra_fpn_name,
# ConvNormLayer(
# ch_in=in_c,
# ch_out=out_channels,
# filter_size=3,
# stride=2,
# norm_type=self.norm_type,
# norm_decay=self.norm_decay,
# freeze_norm=self.freeze_norm,
# initializer=None))
extra_fpn_conv = ConvNormLayer(
ch_in=in_c,
ch_out=out_channels,
filter_size=3,
stride=2,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=None)
else:
# extra_fpn_conv_module.add_module(
# extra_fpn_name,
# nn.Conv2d(
# in_channels=in_c,
# out_channels=out_channels,
# kernel_size=3,
# stride=2,
# padding=1,
# )
# )
extra_fpn_conv = nn.Conv2d(
in_channels=in_c,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
)
self.fpn_convs_module.add_module(extra_fpn_name, extra_fpn_conv)
self.fpn_convs.append(extra_fpn_conv)
@classmethod
def from_config(cls, cfg, input_shape):
return {
'in_channels': [i.channels for i in input_shape],
'spatial_scales': [1.0 / i.stride for i in input_shape],
}
def forward(self, body_feats):
laterals = []
num_levels = len(body_feats)
for i in range(num_levels):
laterals.append(self.lateral_convs[i](body_feats[i]))
for i in range(1, num_levels):
lvl = num_levels - i
upsample = F.interpolate(
laterals[lvl],
scale_factor=2.,
mode='nearest', )
laterals[lvl - 1] += upsample
fpn_output = []
for lvl in range(num_levels):
fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
if self.extra_stage > 0:
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
if not self.has_extra_convs:
assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs'
fpn_output.append(torch.max_pool2d(fpn_output[-1], 1, stride=2))
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
else:
if self.use_c5:
extra_source = body_feats[-1]
else:
extra_source = fpn_output[-1]
fpn_output.append(self.fpn_convs[num_levels](extra_source))
for i in range(1, self.extra_stage):
if self.relu_before_extra_convs:
fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
fpn_output[-1])))
else:
fpn_output.append(self.fpn_convs[num_levels + i](
fpn_output[-1]))
return fpn_output
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/neck/fpn.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv_BN_ReLU(nn.Module):
def __init__(self,
in_planes,
out_planes,
kernel_size=1,
stride=1,
padding=0):
super(Conv_BN_ReLU, self).__init__()
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
self.bn = nn.BatchNorm2d(out_planes, momentum=0.1)
self.relu = nn.ReLU()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
class FPN(nn.Module):
def __init__(self, in_channels, out_channels):
super(FPN, self).__init__()
# Top layer
self.toplayer_ = Conv_BN_ReLU(
in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
# Lateral layers
self.latlayer1_ = Conv_BN_ReLU(
in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
self.latlayer2_ = Conv_BN_ReLU(
in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
self.latlayer3_ = Conv_BN_ReLU(
in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
# Smooth layers
self.smooth1_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.smooth2_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.smooth3_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.out_channels = out_channels * 4
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def _upsample(self, x, scale=1):
return F.interpolate(x, scale_factor=scale, mode='bilinear')
def _upsample_add(self, x, y, scale=1):
return F.interpolate(x, scale_factor=scale, mode='bilinear') + y
def forward(self, x):
f2, f3, f4, f5 = x
p5 = self.toplayer_(f5)
f4 = self.latlayer1_(f4)
p4 = self._upsample_add(p5, f4, 2)
p4 = self.smooth1_(p4)
f3 = self.latlayer2_(f3)
p3 = self._upsample_add(p4, f3, 2)
p3 = self.smooth2_(p3)
f2 = self.latlayer3_(f2)
p2 = self._upsample_add(p3, f2, 2)
p2 = self.smooth3_(p2)
p3 = self._upsample(p3, 2)
p4 = self._upsample(p4, 4)
p5 = self._upsample(p5, 8)
fuse = torch.cat([p2, p3, p4, p5], dim=1)
return fuse
from torch import nn
class IntraCLBlock(nn.Module):
def __init__(self, in_channels=96, reduce_factor=4):
super(IntraCLBlock, self).__init__()
self.channels = in_channels
self.rf = reduce_factor
self.conv1x1_reduce_channel = nn.Conv2d(
self.channels,
self.channels // self.rf,
kernel_size=1,
stride=1,
padding=0)
self.conv1x1_return_channel = nn.Conv2d(
self.channels // self.rf,
self.channels,
kernel_size=1,
stride=1,
padding=0)
self.v_layer_7x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 1),
stride=(1, 1),
padding=(3, 0))
self.v_layer_5x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 1),
stride=(1, 1),
padding=(2, 0))
self.v_layer_3x1 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 1),
stride=(1, 1),
padding=(1, 0))
self.q_layer_1x7 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 7),
stride=(1, 1),
padding=(0, 3))
self.q_layer_1x5 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 5),
stride=(1, 1),
padding=(0, 2))
self.q_layer_1x3 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(1, 3),
stride=(1, 1),
padding=(0, 1))
# base
self.c_layer_7x7 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(7, 7),
stride=(1, 1),
padding=(3, 3))
self.c_layer_5x5 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2))
self.c_layer_3x3 = nn.Conv2d(
self.channels // self.rf,
self.channels // self.rf,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1))
self.bn = nn.BatchNorm2d(self.channels)
self.relu = nn.ReLU()
def forward(self, x):
x_new = self.conv1x1_reduce_channel(x)
x_7_c = self.c_layer_7x7(x_new)
x_7_v = self.v_layer_7x1(x_new)
x_7_q = self.q_layer_1x7(x_new)
x_7 = x_7_c + x_7_v + x_7_q
x_5_c = self.c_layer_5x5(x_7)
x_5_v = self.v_layer_5x1(x_7)
x_5_q = self.q_layer_1x5(x_7)
x_5 = x_5_c + x_5_v + x_5_q
x_3_c = self.c_layer_3x3(x_5)
x_3_v = self.v_layer_3x1(x_5)
x_3_q = self.q_layer_1x3(x_5)
x_3 = x_3_c + x_3_v + x_3_q
x_relation = self.conv1x1_return_channel(x_3)
x_relation = self.bn(x_relation)
x_relation = self.relu(x_relation)
return x + x_relation
def build_intraclblock_list(num_block):
IntraCLBlock_list = nn.ModuleList()
for i in range(num_block):
IntraCLBlock_list.append(IntraCLBlock())
return IntraCLBlock_list
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm2d(out_channels)
self.act = act
if self.act is not None:
self._act = Activation(act_type=self.act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DeConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.act = act
if self.act is not None:
self._act = Activation(act_type=self.act, inplace=True)
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
if self.act is not None:
x = self._act(x)
return x
class PGFPN(nn.Module):
def __init__(self, in_channels, **kwargs):
super(PGFPN, self).__init__()
num_inputs = [2048, 2048, 1024, 512, 256]
num_outputs = [256, 256, 192, 192, 128]
self.out_channels = 128
self.conv_bn_layer_1 = ConvBNLayer(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
act=None,
name='FPN_d1')
self.conv_bn_layer_2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act=None,
name='FPN_d2')
self.conv_bn_layer_3 = ConvBNLayer(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=1,
act=None,
name='FPN_d3')
self.conv_bn_layer_4 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=2,
act=None,
name='FPN_d4')
self.conv_bn_layer_5 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d5')
self.conv_bn_layer_6 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=2,
act=None,
name='FPN_d6')
self.conv_bn_layer_7 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d7')
self.conv_bn_layer_8 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=1,
stride=1,
act=None,
name='FPN_d8')
self.conv_h0 = ConvBNLayer(
in_channels=num_inputs[0],
out_channels=num_outputs[0],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(0))
self.conv_h1 = ConvBNLayer(
in_channels=num_inputs[1],
out_channels=num_outputs[1],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(1))
self.conv_h2 = ConvBNLayer(
in_channels=num_inputs[2],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(2))
self.conv_h3 = ConvBNLayer(
in_channels=num_inputs[3],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(3))
self.conv_h4 = ConvBNLayer(
in_channels=num_inputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(4))
self.dconv0 = DeConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[0 + 1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1 + 1],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2 + 1],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3 + 1],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def forward(self, x):
c0, c1, c2, c3, c4, c5, c6 = x
# FPN_Down_Fusion
f = [c0, c1, c2]
g = [None, None, None]
h = [None, None, None]
h[0] = self.conv_bn_layer_1(f[0])
h[1] = self.conv_bn_layer_2(f[1])
h[2] = self.conv_bn_layer_3(f[2])
g[0] = self.conv_bn_layer_4(h[0])
g[1] = torch.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_bn_layer_5(g[1])
g[1] = self.conv_bn_layer_6(g[1])
g[2] = torch.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_bn_layer_7(g[2])
f_down = self.conv_bn_layer_8(g[2])
# FPN UP Fusion
f1 = [c6, c5, c4, c3, c2]
g = [None, None, None, None, None]
h = [None, None, None, None, None]
h[0] = self.conv_h0(f1[0])
h[1] = self.conv_h1(f1[1])
h[2] = self.conv_h2(f1[2])
h[3] = self.conv_h3(f1[3])
h[4] = self.conv_h4(f1[4])
g[0] = self.dconv0(h[0])
g[1] = torch.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_g1(g[1])
g[1] = self.dconv1(g[1])
g[2] = torch.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_g2(g[2])
g[2] = self.dconv2(g[2])
g[3] = torch.add(g[2], h[3])
g[3] = F.relu(g[3])
g[3] = self.conv_g3(g[3])
g[3] = self.dconv3(g[3])
g[4] = torch.add(g[3], h[4])
g[4] = F.relu(g[4])
g[4] = self.conv_g4(g[4])
f_up = self.convf(g[4])
f_common = torch.add(f_down, f_up)
f_common = F.relu(f_common)
return f_common
import os, sys
import torch
import torch.nn as nn
from pytorchocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer
class Im2Seq(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
# assert H == 1
x = x.squeeze(dim=2)
# x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
x = x.permute(0,2,1)
return x
class EncoderWithRNN_(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN_, self).__init__()
self.out_channels = hidden_size * 2
self.rnn1 = nn.LSTM(in_channels, hidden_size, bidirectional=False, batch_first=True, num_layers=2)
self.rnn2 = nn.LSTM(in_channels, hidden_size, bidirectional=False, batch_first=True, num_layers=2)
def forward(self, x):
self.rnn1.flatten_parameters()
self.rnn2.flatten_parameters()
out1, h1 = self.rnn1(x)
out2, h2 = self.rnn2(torch.flip(x, [1]))
return torch.cat([out1, torch.flip(out2, [1])], 2)
class EncoderWithRNN(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(
in_channels, hidden_size, num_layers=2, batch_first=True, bidirectional=True) # batch_first:=True
def forward(self, x):
x, _ = self.lstm(x)
return x
class EncoderWithFC(nn.Module):
def __init__(self, in_channels, hidden_size):
super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size
self.fc = nn.Linear(
in_channels,
hidden_size,
bias=True,
)
def forward(self, x):
x = self.fc(x)
return x
class EncoderWithSVTR(nn.Module):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
kernel_size=[3,3],
attn_drop_rate=0.1,
drop_path=0.,
qk_scale=None):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels,
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
act='swish')
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act='swish')
self.svtr_block = nn.ModuleList([
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer='Global',
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer='swish',
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer='nn.LayerNorm',
epsilon=1e-05,
prenorm=False) for i in range(depth)
])
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
self.conv3 = ConvBNLayer(
hidden_dims, in_channels, kernel_size=1, act='swish')
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act='swish')
self.conv1x1 = ConvBNLayer(
in_channels // 8, dims, kernel_size=1, act='swish')
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# for use guide
if self.use_guide:
z = x.clone()
z.stop_gradient = True
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).permute(0, 2, 1)
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
z = self.conv3(z)
z = torch.cat((h, z), dim=1)
z = self.conv1x1(self.conv4(z))
return z
class SequenceEncoder(nn.Module):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == 'reshape':
self.only_reshape = True
else:
support_encoder_dict = {
'reshape': Im2Seq,
'fc': EncoderWithFC,
'rnn': EncoderWithRNN,
'svtr': EncoderWithSVTR,
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, **kwargs)
else:
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
if self.encoder_type != 'svtr':
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment