"doc/vscode:/vscode.git/clone" did not exist on "b90989358507651a6273a7bdfd7b2c9e7f4e6004"
Unverified Commit 006d84bf authored by 崔浩's avatar 崔浩 Committed by GitHub
Browse files

Merge branch 'PaddlePaddle:dygraph' into dygraph

parents 302ca30c 8beeb84c
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
class SAREncoder(nn.Layer):
    """Encoder of the SAR head: turns a feature map into one vector per image.

    The feature map is max-pooled over its height, fed through a two-layer
    RNN along the width axis, and the RNN output at one time step per sample
    is passed through a linear layer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # Recurrent encoder: GRU or LSTM, two stacked layers, batch-major.
        rnn_cls = nn.GRU if enc_gru else nn.LSTM
        self.rnn_encoder = rnn_cls(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction='bidirectional' if enc_bi_rnn else 'forward')

        # Global feature transformation; a bidirectional RNN doubles the
        # width of its output.
        rnn_out_size = d_enc * 2 if enc_bi_rnn else d_enc
        self.linear = nn.Linear(rnn_out_size, rnn_out_size)

    def forward(self, feat, img_metas=None):
        """Encode ``feat`` (bsz, c, h, w) into a (bsz, C) holistic feature.

        When ``img_metas`` is given and masking is enabled, its last element
        is read as the per-sample valid ratios and the RNN output is taken at
        the last valid step of each sample; otherwise the final step is used.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        masked = img_metas is not None and self.mask
        valid_ratios = img_metas[-1] if masked else None

        height = feat.shape[2]  # bsz c h w
        pooled = F.max_pool2d(
            feat, kernel_size=(height, 1), stride=1, padding=0)
        # bsz * C * W -> bsz * W * C
        sequence = paddle.transpose(pooled.squeeze(2), perm=[0, 2, 1])
        rnn_out = self.rnn_encoder(sequence)[0]  # bsz * T * C

        if valid_ratios is None:
            selected = rnn_out[:, -1, :]  # bsz * C
        else:
            steps = rnn_out.shape[1]
            selected = paddle.stack(
                [
                    rnn_out[idx, min(steps, math.ceil(steps * ratio)) - 1, :]
                    for idx, ratio in enumerate(valid_ratios)
                ],
                axis=0)

        return self.linear(selected)  # bsz * C
class BaseDecoder(nn.Layer):
    """Base class that routes ``forward`` to a train or a test path.

    Subclasses implement ``forward_train`` and ``forward_test``.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding; must be overridden."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding; must be overridden."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Record ``train_mode`` on the instance and dispatch accordingly."""
        self.train_mode = train_mode
        if not train_mode:
            return self.forward_test(feat, out_enc, img_metas)
        return self.forward_train(feat, out_enc, label, img_metas)
class ParallelSARDecoder(BaseDecoder):
    """Decoder of the SAR head with a 2D attention module.

    Args:
        out_channels (int): Output class number. The start token index is
            ``out_channels - 2`` and the padding token index is
            ``out_channels - 1``.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        # Named rnn_kwargs so the **kwargs parameter is not shadowed.
        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            # The holistic feature concatenated in _2d_attention has the
            # encoder output width, which is 2 * d_enc when enc_bi_rnn is
            # True, so size the layer from encoder_rnn_out_size, not d_enc.
            fc_in_channel = (decoder_rnn_out_size + d_model +
                             encoder_rnn_out_size)
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN and attend over the 2D feature map.

        Args:
            decoder_input: bsz * (seq_len + 1) * emb_dim decoder inputs.
            feat: bsz * c * h * w feature map.
            holistic_feat: bsz * 1 * C encoder output.
            valid_ratios: Optional per-sample width ratios; columns beyond
                the valid width get an attention logit of -inf.

        Returns:
            bsz * (seq_len + 1) * num_classes prediction scores. Reads
            ``self.train_mode``, which ``forward`` sets before dispatching.
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w positions.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Teacher-forced decoding.

        img_metas: [label, valid_ratio]

        Returns bsz * seq_len * num_classes scores; the first output step,
        which corresponds to the holistic feature input, is dropped.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        # The label is used on whatever device it already lives on; forcing
        # it onto a GPU here made CPU-only runs fail.
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        """Greedy decoding for ``self.max_seq_len`` steps.

        The decoder input starts as the holistic feature followed by start
        token embeddings. After each step the embedding of the arg-max class
        is written into the next input position, and the whole sequence is
        decoded again.

        Returns:
            bsz * seq_len * num_classes softmax probabilities.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
class SARHead(nn.Layer):
    """Recognition head that chains ``SAREncoder`` and ``ParallelSARDecoder``.

    The encoder receives only the ``enc_*`` options, so its other settings
    keep their defaults. The decoder receives ``out_channels``, the ``dec_*``
    options, ``enc_bi_rnn``, ``d_k``, ``pred_dropout``, ``max_text_length``
    and ``pred_concat``. Extra keyword arguments are ignored.
    """

    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        encoder_options = dict(
            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
        self.encoder = SAREncoder(**encoder_options)

        # decoder module
        decoder_options = dict(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)
        self.decoder = ParallelSARDecoder(**decoder_options)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into (bsz, seq_len, num_classes) scores.

        targets is forwarded as img_metas: [label, valid_ratio]. In training
        mode targets[0] is converted to an int64 tensor and used as the label.
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = paddle.to_tensor(targets[0], dtype='int64')
            return self.decoder(
                feat, holistic_feat, label, img_metas=targets)

        return self.decoder(
            feat,
            holistic_feat,
            label=None,
            img_metas=targets,
            train_mode=False)
...@@ -22,7 +22,8 @@ def build_neck(config): ...@@ -22,7 +22,8 @@ def build_neck(config):
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
from .pg_fpn import PGFPN from .pg_fpn import PGFPN
from .table_fpn import TableFPN from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] from .fpn import FPN
support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle
import math
import paddle.nn.functional as F
class Conv_BN_ReLU(nn.Layer):
    """Bias-free 2D convolution followed by batch norm and ReLU.

    After construction the convolution weight is replaced by a parameter
    drawn from Normal(0, sqrt(2 / n)) with n = kh * kw * out_channels, and
    the batch-norm weight and bias are replaced by constant 1 and 0.
    """

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
        super(Conv_BN_ReLU, self).__init__()
        self.conv = nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
        self.relu = nn.ReLU()

        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                fan = (layer._kernel_size[0] * layer._kernel_size[1] *
                       layer._out_channels)
                conv_init = paddle.nn.initializer.Normal(0, math.sqrt(2. / fan))
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=conv_init)
            elif isinstance(layer, nn.BatchNorm2D):
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                layer.bias = paddle.create_parameter(
                    shape=layer.bias.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def forward(self, x):
        """Apply conv, then batch norm, then ReLU."""
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)
class FPN(nn.Layer):
    """Feature pyramid neck that fuses four feature maps into one.

    Args:
        in_channels: Channel counts of the four inputs; indices 0..3 are
            matched to the inputs of ``forward`` in order.
        out_channels (int): Channels of every pyramid level. The fused
            output has ``out_channels * 4`` channels, exposed as
            ``self.out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(FPN, self).__init__()

        # Top layer
        self.toplayer_ = Conv_BN_ReLU(
            in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)

        # Lateral layers
        self.latlayer1_ = Conv_BN_ReLU(
            in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer2_ = Conv_BN_ReLU(
            in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer3_ = Conv_BN_ReLU(
            in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)

        # Smooth layers
        self.smooth1_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.smooth2_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.smooth3_ = Conv_BN_ReLU(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.out_channels = out_channels * 4

        # Re-create conv and batch-norm parameters of every sublayer.
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                fan = (layer._kernel_size[0] * layer._kernel_size[1] *
                       layer._out_channels)
                conv_init = paddle.nn.initializer.Normal(0, math.sqrt(2. / fan))
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=conv_init)
            elif isinstance(layer, nn.BatchNorm2D):
                layer.weight = paddle.create_parameter(
                    shape=layer.weight.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                layer.bias = paddle.create_parameter(
                    shape=layer.bias.shape,
                    dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def _upsample(self, x, scale=1):
        """Bilinearly upsample ``x`` by ``scale``."""
        return F.upsample(x, scale_factor=scale, mode='bilinear')

    def _upsample_add(self, x, y, scale=1):
        """Bilinearly upsample ``x`` by ``scale`` and add ``y``."""
        upsampled = F.upsample(x, scale_factor=scale, mode='bilinear')
        return upsampled + y

    def forward(self, x):
        """Fuse the four feature maps in ``x`` along the channel axis.

        Each level is built top-down: the level above is upsampled by 2,
        added to the lateral projection and smoothed. The three upper levels
        are then upsampled by 2, 4 and 8 before concatenation.
        """
        c2, c3, c4, c5 = x

        top = self.toplayer_(c5)

        lateral4 = self.latlayer1_(c4)
        level4 = self.smooth1_(self._upsample_add(top, lateral4, 2))

        lateral3 = self.latlayer2_(c3)
        level3 = self.smooth2_(self._upsample_add(level4, lateral3, 2))

        lateral2 = self.latlayer3_(c2)
        level2 = self.smooth3_(self._upsample_add(level3, lateral2, 2))

        level3_up = self._upsample(level3, 2)
        level4_up = self._upsample(level4, 4)
        top_up = self._upsample(top, 8)

        return paddle.concat([level2, level3_up, level4_up, top_up], axis=1)
\ No newline at end of file
...@@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer): ...@@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer):
super(EncoderWithFC, self).__init__() super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size self.out_channels = hidden_size
weight_attr, bias_attr = get_para_bias_attr( weight_attr, bias_attr = get_para_bias_attr(
l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea') l2_decay=0.00001, k=in_channels)
self.fc = nn.Linear( self.fc = nn.Linear(
in_channels, in_channels,
hidden_size, hidden_size,
......
...@@ -17,8 +17,9 @@ __all__ = ['build_transform'] ...@@ -17,8 +17,9 @@ __all__ = ['build_transform']
def build_transform(config): def build_transform(config):
from .tps import TPS from .tps import TPS
from .stn import STN_ON
support_dict = ['TPS'] support_dict = ['TPS', 'STN_ON']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .tps_spatial_transformer import TPSSpatialTransformer
def conv3x3_block(in_channels, out_channels, stride=1):
    """Build a 3x3 conv, batch norm, ReLU block.

    The conv weight is initialised from Normal(0, sqrt(2 / (9 * out_channels)))
    and its bias from a constant 0.
    """
    fan = 3 * 3 * out_channels
    std = math.sqrt(2. / fan)

    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        weight_attr=nn.initializer.Normal(
            mean=0.0, std=std),
        bias_attr=nn.initializer.Constant(0))

    return nn.Sequential(conv, nn.BatchNorm2D(out_channels), nn.ReLU())
class STN(nn.Layer):
    """Localisation network that predicts control points from an image.

    Args:
        in_channels (int): Channels of the input image.
        num_ctrlpoints (int): Number of control points to predict; each one
            is an (x, y) pair.
        activation (str): ``'sigmoid'`` applies a sigmoid to the predicted
            points; any other value leaves them unchanged.
    """

    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
        super(STN, self).__init__()
        self.in_channels = in_channels
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation

        # Trailing comments are the original author's spatial-size notes.
        conv_stack = [
            conv3x3_block(in_channels, 32),  # 32x64
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(32, 64),  # 16x32
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(64, 128),  # 8*16
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(128, 256),  # 4*8
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(256, 256),  # 2*4
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(256, 256),  # 1*2
        ]
        self.stn_convnet = nn.Sequential(*conv_stack)

        self.stn_fc1 = nn.Sequential(
            nn.Linear(
                2 * 256,
                512,
                weight_attr=nn.initializer.Normal(0, 0.001),
                bias_attr=nn.initializer.Constant(0)),
            nn.BatchNorm1D(512),
            nn.ReLU())

        # Zero weights plus this bias make the initial prediction equal to
        # the points built by init_stn.
        initial_bias = self.init_stn()
        self.stn_fc2 = nn.Linear(
            512,
            num_ctrlpoints * 2,
            weight_attr=nn.initializer.Constant(0.0),
            bias_attr=nn.initializer.Assign(initial_bias))

    def init_stn(self):
        """Return the flattened initial control points used as fc2 bias.

        Half of the points lie on a top row (y = margin) and half on a
        bottom row (y = 1 - margin), evenly spaced in x. With the sigmoid
        activation the points are passed through the logit function first.
        """
        margin = 0.01
        per_side = int(self.num_ctrlpoints / 2)

        xs = np.linspace(margin, 1. - margin, per_side)
        top_ys = np.ones(per_side) * margin
        bottom_ys = np.ones(per_side) * (1 - margin)

        top_row = np.stack([xs, top_ys], axis=1)
        bottom_row = np.stack([xs, bottom_ys], axis=1)
        ctrl_points = np.concatenate(
            [top_row, bottom_row], axis=0).astype(np.float32)

        if self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)

        ctrl_points = paddle.to_tensor(ctrl_points)
        flat_len = ctrl_points.shape[0] * ctrl_points.shape[1]
        return paddle.reshape(ctrl_points, shape=[flat_len])

    def forward(self, x):
        """Return ``(img_feat, ctrl_points)`` for the image batch ``x``.

        ``img_feat`` is the output of ``stn_fc1``. ``ctrl_points`` has shape
        (-1, num_ctrlpoints, 2) and is computed from ``0.1 * img_feat``.
        """
        conv_out = self.stn_convnet(x)
        batch_size, _, _, _ = conv_out.shape
        flat = paddle.reshape(conv_out, shape=(batch_size, -1))

        img_feat = self.stn_fc1(flat)
        points = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            points = F.sigmoid(points)
        points = paddle.reshape(points, shape=[-1, self.num_ctrlpoints, 2])
        return img_feat, points
class STN_ON(nn.Layer):
    """Rectification transform: STN control-point prediction plus TPS warp.

    Args:
        in_channels (int): Channels of the input image; also exposed as
            ``self.out_channels``.
        tps_inputsize: Size the image is resized to before it is passed to
            the STN head.
        tps_outputsize: Output image size handed to the TPS transformer.
        num_control_points (int): Number of control points.
        tps_margins: Margins handed to the TPS transformer.
        stn_activation (str): Activation option of the STN head.
    """

    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
                 num_control_points, tps_margins, stn_activation):
        super(STN_ON, self).__init__()
        self.tps_inputsize = tps_inputsize
        self.out_channels = in_channels

        self.tps = TPSSpatialTransformer(
            output_image_size=tuple(tps_outputsize),
            num_control_points=num_control_points,
            margins=tuple(tps_margins))
        self.stn_head = STN(
            in_channels=in_channels,
            num_ctrlpoints=num_control_points,
            activation=stn_activation)

    def forward(self, image):
        """Return ``image`` warped by the predicted control points.

        Control points are predicted on a resized copy of the image, while
        the warp itself is applied to the original ``image``.
        """
        resized = paddle.nn.functional.interpolate(
            image, self.tps_inputsize, mode="bilinear", align_corners=True)
        _, ctrl_points = self.stn_head(resized)
        rectified, _ = self.tps(image, ctrl_points)
        return rectified
...@@ -231,7 +231,8 @@ class GridGenerator(nn.Layer): ...@@ -231,7 +231,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """ """ Return inv_delta_C which is needed to calculate T """
F = self.F F = self.F
hat_eye = paddle.eye(F, dtype='float64') # F x F hat_eye = paddle.eye(F, dtype='float64') # F x F
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = paddle.norm(
C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C) hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3 delta_C = paddle.concat( # F+3 x F+3
[ [
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
import itertools
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at ``grid``, optionally filling gaps from ``canvas``.

    Sets ``input.stop_gradient = False`` as a side effect. When ``canvas`` is
    given, an all-ones mask is sampled with the same grid and the result is
    ``output * mask + canvas * (1 - mask)``.
    """
    input.stop_gradient = False
    sampled = F.grid_sample(input, grid)
    if canvas is None:
        return sampled

    ones = paddle.ones(shape=input.shape)
    coverage = F.grid_sample(ones, grid)
    return sampled * coverage + canvas * (1 - coverage)
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    """Return the N x M matrix of r^2 * log(r) between two 2D point sets.

    It is computed as 0.5 * d * log(d), where d is the squared distance.
    NaN entries, which arise from 0 * log(0), are replaced by 0.
    """
    num_inputs = input_points.shape[0]
    num_controls = control_points.shape[0]

    inputs_col = paddle.reshape(input_points, shape=[num_inputs, 1, 2])
    controls_row = paddle.reshape(control_points, shape=[1, num_controls, 2])
    delta = inputs_col - controls_row

    # Sum the two squared components explicitly; the original author notes
    # that a generic reduction over the last axis was very slow.
    delta_sq = delta * delta
    sq_dist = delta_sq[:, :, 0] + delta_sq[:, :, 1]

    repr_matrix = 0.5 * sq_dist * paddle.log(sq_dist)

    # NaN is the only value that differs from itself.
    nan_mask = repr_matrix != repr_matrix
    repr_matrix[nan_mask] = 0
    return repr_matrix
# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
    """Return the target control points as a (2 * (n // 2)) x 2 tensor.

    Half of the points lie on a top row (y = margin_y) and half on a bottom
    row (y = 1 - margin_y), evenly spaced in x over
    [margin_x, 1 - margin_x]. Top-row points come first.
    """
    margin_x, margin_y = margins
    per_side = num_control_points // 2

    xs = np.linspace(margin_x, 1.0 - margin_x, per_side)
    top_ys = np.ones(per_side) * margin_y
    bottom_ys = np.ones(per_side) * (1.0 - margin_y)

    rows = [
        np.stack([xs, top_ys], axis=1),
        np.stack([xs, bottom_ys], axis=1),
    ]
    points = np.concatenate(rows, axis=0)
    return paddle.to_tensor(points)
class TPSSpatialTransformer(nn.Layer):
    """Thin-plate-spline warp driven by predicted source control points.

    The constructor precomputes everything that depends only on the target
    layout: the inverse of the padded kernel matrix and the representation
    of every output pixel coordinate.

    Args:
        output_image_size: ``(height, width)`` of the warped output.
        num_control_points (int): Number of control points N.
        margins: ``(margin_x, margin_y)`` for the target control points.
    """

    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points,
                                                            margins)
        N = num_control_points

        # create padded kernel matrix
        forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
        target_control_partial_repr = compute_partial_repr(
            target_control_points, target_control_points)
        target_control_partial_repr = paddle.cast(target_control_partial_repr,
                                                  forward_kernel.dtype)
        forward_kernel[:N, :N] = target_control_partial_repr
        forward_kernel[:N, -3] = 1
        forward_kernel[-3, :N] = 1
        target_control_points = paddle.cast(target_control_points,
                                            forward_kernel.dtype)
        forward_kernel[:N, -2:] = target_control_points
        forward_kernel[-2:, :N] = paddle.transpose(
            target_control_points, perm=[1, 0])
        # compute inverse matrix
        inverse_kernel = paddle.inverse(forward_kernel)

        # create target coordinate matrix
        HW = self.target_height * self.target_width
        target_coordinate = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        target_coordinate = paddle.to_tensor(target_coordinate)  # HW x 2
        Y, X = paddle.split(
            target_coordinate, target_coordinate.shape[1], axis=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = paddle.concat(
            [X, Y], axis=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(
            target_coordinate, target_control_points)
        target_coordinate_repr = paddle.concat(
            [
                target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
                target_coordinate
            ],
            axis=1)

        # register precomputed matrices
        self.inverse_kernel = inverse_kernel
        self.padding_matrix = paddle.zeros(shape=[3, 2])
        self.target_coordinate_repr = target_coordinate_repr
        self.target_control_points = target_control_points

    def forward(self, input, source_control_points):
        """Warp ``input`` so the source control points map to the targets.

        Args:
            input: Image batch to sample from.
            source_control_points: Tensor of shape
                (batch, num_control_points, 2).

        Returns:
            Tuple ``(output_maps, source_coordinate)``: the sampled images
            and the source coordinates computed for every output pixel,
            taken before clipping.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.shape[1] == self.num_control_points
        assert source_control_points.shape[2] == 2
        batch_size = paddle.shape(source_control_points)[0]

        # Expand into a local tensor. Overwriting self.padding_matrix left it
        # with shape [batch, 3, 2], so the next call with a different batch
        # size could no longer expand it.
        padding_matrix = paddle.expand(
            self.padding_matrix, shape=[batch_size, 3, 2])
        Y = paddle.concat([source_control_points, padding_matrix], 1)
        mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
        source_coordinate = paddle.matmul(self.target_coordinate_repr,
                                          mapping_matrix)

        grid = paddle.reshape(
            source_coordinate,
            shape=[-1, self.target_height, self.target_width, 2])
        grid = paddle.clip(grid, 0,
                           1)  # the source_control_points may be out of [0, 1].
        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
        grid = 2.0 * grid - 1.0
        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate
...@@ -127,3 +127,34 @@ class RMSProp(object): ...@@ -127,3 +127,34 @@ class RMSProp(object):
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=parameters) parameters=parameters)
return opt return opt
class Adadelta(object):
    """Configuration holder that builds an ``optim.Adadelta`` optimizer.

    Args:
        learning_rate: Passed to the optimizer as ``learning_rate``.
        epsilon (float): Passed to the optimizer as ``epsilon``.
        rho (float): Passed to the optimizer as ``rho``.
        parameter_list: Stored on the instance; not used by ``__call__``,
            which takes the parameters as its argument instead.
        weight_decay: Passed to the optimizer as ``weight_decay``.
        grad_clip: Passed to the optimizer as ``grad_clip``.
        name: Passed to the optimizer as ``name``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 learning_rate=0.001,
                 epsilon=1e-08,
                 rho=0.95,
                 parameter_list=None,
                 weight_decay=None,
                 grad_clip=None,
                 name=None,
                 **kwargs):
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.rho = rho
        self.parameter_list = parameter_list
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.name = name

    def __call__(self, parameters):
        """Return an ``optim.Adadelta`` over ``parameters``."""
        opt = optim.Adadelta(
            learning_rate=self.learning_rate,
            epsilon=self.epsilon,
            rho=self.rho,
            weight_decay=self.weight_decay,
            grad_clip=self.grad_clip,
            name=self.name,
            parameters=parameters)
        return opt
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
import platform
__all__ = ['build_post_process'] __all__ = ['build_post_process']
...@@ -25,17 +26,22 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess ...@@ -25,17 +26,22 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
if platform.system() != "Windows":
# pse is not support in Windows
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess' 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pse_postprocess import PSEPostProcess
\ No newline at end of file
## 编译
The code is adapted from https://github.com/whai362/pan_pp.pytorch
```bash
python3 setup.py build_ext --inplace
```
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import subprocess
python_path = sys.executable

# Build the Cython extension at import time, in the directory that holds this
# file. Running with cwd= instead of a shell `cd` to a repo-relative path
# means the build no longer depends on the caller's working directory, and
# passing an argument list avoids shell quoting problems when the interpreter
# path contains spaces.
# NOTE(review): assumes setup.py and pse.pyx sit next to this file, as the
# relative import that follows and the error message suggest - confirm.
_pse_dir = os.path.dirname(os.path.realpath(__file__))
if subprocess.call(
        [python_path, 'setup.py', 'build_ext', '--inplace'],
        cwd=_pse_dir) != 0:
    raise RuntimeError('Cannot compile pse: {}'.format(_pse_dir))
from .pse import pse
\ No newline at end of file
import numpy as np
import cv2
cimport numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
from libcpp.pair cimport *
from libcpp.queue cimport *
# Grow the labelled seed regions of `label` outward through the kernel masks
# with a breadth-first search over 4-connected neighbours.
#
#   kernels    : uint8 array indexed as [kernel_idx, x, y]; a zero entry
#                blocks expansion into that pixel.
#   label      : int32 seed label map. It is modified in place: labels
#                1 .. label_num - 1 covering fewer than min_area pixels are
#                reset to 0.
#   kernel_num : number of expansion passes. Passes index `kernels` from
#                kernel_num - 1 down to 0, so `kernels` needs at least
#                kernel_num entries on its first axis. Bounds checking is
#                disabled here, so a shorter array is not caught.
#   label_num  : exclusive upper bound of the labels checked against min_area.
#   min_area   : minimum pixel count for a seed region to be kept.
#
# Returns an int32 map with the same 2D shape as `label`, holding the
# expanded labels.
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                         np.ndarray[np.int32_t, ndim=2] label,
                                         int kernel_num,
                                         int label_num,
                                         float min_area=0):
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    # Drop seed regions that are smaller than min_area.
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    # que holds the frontier of the current pass; nxt_que collects pixels
    # that could not grow and become the frontier of the next pass.
    # Coordinates are stored as int16.
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    # Offsets of the four direct neighbours.
    cdef np.int16_t* dx = [-1, 1, 0, 0]
    cdef np.int16_t* dy = [0, 0, -1, 1]
    cdef np.int16_t tmpx, tmpy

    # Seed the queue and the output with every remaining labelled pixel.
    points = np.array(np.where(label > 0)).transpose((1, 0))
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
    cdef int cur_label
    for kernel_idx in range(kernel_num - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                # Skip neighbours that fall outside the map.
                if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
                    continue
                # Skip neighbours outside this kernel or already labelled;
                # the first label to reach a pixel keeps it.
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            # A pixel that grew nowhere is retried against the next kernel.
            if is_edge:
                nxt_que.push(cur)

        que, nxt_que = nxt_que, que

    return pred
def pse(kernels, min_area):
    """Run progressive scale expansion on a stack of binary kernels.

    Args:
        kernels: uint8 array indexed as [kernel, row, col]. The last kernel
            provides the seeds; the remaining ones are expanded through,
            from the second-to-last down to the first.
        min_area: seed components with fewer pixels than this are discarded.

    Returns:
        int32 label map with the same height/width as one kernel.
    """
    # Seeds are the 4-connected components of the last kernel.
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    expand_kernels = kernels[:-1]
    # Pass the number of kernels actually contained in the sliced array.
    # `_pse` runs with bounds checking disabled, so handing it the full
    # kernel count made it index one position past the end of the slice.
    # That extra pass was over the seed kernel itself and could never expand
    # anything, so the result is unchanged.
    return _pse(expand_kernels, label, expand_kernels.shape[0], label_num,
                min_area)
\ No newline at end of file
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
# Describe the ``pse`` extension: a single Cython source compiled as C++
# against the NumPy headers with -O3.
pse_extension = Extension(
    'pse',
    sources=['pse.pyx'],
    language='c++',
    include_dirs=[numpy.get_include()],
    library_dirs=[],
    libraries=[],
    extra_compile_args=['-O3'],
    extra_link_args=[])

# Translate the .pyx source and hand the resulting module list to distutils.
cythonized_modules = cythonize(pse_extension)
setup(ext_modules=cythonized_modules)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import paddle
from paddle.nn import functional as F
from ppocr.postprocess.pse_postprocess.pse import pse
class PSEPostProcess(object):
    """
    The post process for PSE.

    Turns the network output maps into text boxes / polygons by thresholding
    them into binary kernels, running ``pse`` on the kernels and fitting a
    shape around every resulting label.

    Args:
        thresh: threshold applied to the (interpolated) maps to binarize them.
        box_thresh: minimum mean score a region needs to be kept.
        min_area: minimum pixel count a region needs to be kept.
        box_type: 'box' for a min-area rectangle, 'poly' for the outer contour.
        scale: the maps are upsampled by ``4 // scale`` before processing.
    """

    def __init__(self,
                 thresh=0.5,
                 box_thresh=0.85,
                 min_area=16,
                 box_type='box',
                 scale=4,
                 **kwargs):
        assert box_type in ['box', 'poly'], 'Only box and poly is supported'
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.min_area = min_area
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        """Decode a batch of maps into boxes.

        Args:
            outs_dict: dict whose 'maps' entry holds the 4-D output maps
                (a ``paddle.Tensor`` or anything ``paddle.to_tensor`` accepts).
            shape_list: per-image ``(src_h, src_w, ratio_h, ratio_w)``.

        Returns:
            One ``{'points': boxes, 'scores': scores}`` dict per image.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(
            pred, scale_factor=4 // self.scale, mode='bilinear')

        # Channel 0 is used as the text score map.
        score = F.sigmoid(pred[:, 0, :, :])

        kernels = (pred > self.thresh).astype('float32')
        # Keep the channel axis on the text mask so it broadcasts over every
        # kernel channel. Indexing the channel away left a 3-D mask whose
        # leading (batch) axis was matched against the channel axis, which
        # fails or mixes images for batch sizes other than 1.
        text_mask = kernels[:, 0:1, :, :]
        kernels = kernels * text_mask

        score = score.numpy()
        kernels = kernels.numpy().astype(np.uint8)

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            boxes, scores = self.boxes_from_bitmap(score[batch_index],
                                                   kernels[batch_index],
                                                   shape_list[batch_index])
            boxes_batch.append({'points': boxes, 'scores': scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, shape):
        """Label one image's kernels with ``pse`` and convert labels to boxes."""
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, shape)

    def generate_box(self, score, label, shape):
        """Fit a box or polygon around every sufficiently large, confident label.

        Args:
            score: 2-D score map aligned with ``label``.
            label: 2-D integer label map; rejected labels are zeroed in place.
            shape: ``(src_h, src_w, ratio_h, ratio_w)`` used to map the points
                back to source-image coordinates and clip them.

        Returns:
            ``(boxes, scores)``: the point arrays and the mean score of each
            kept region.
        """
        src_h, src_w, ratio_h, ratio_w = shape
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            # np.where yields (row, col); reverse to (x, y) for OpenCV.
            points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < self.min_area:
                label[ind] = 0
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                label[ind] = 0
                continue

            if self.box_type == 'box':
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == 'poly':
                # Rasterize the region on a canvas slightly larger than its
                # extent and take the first external contour.
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            # Rescale to the source image and clip to its bounds.
            bbox[:, 0] = np.clip(
                np.round(bbox[:, 0] / ratio_w), 0, src_w)
            bbox[:, 1] = np.clip(
                np.round(bbox[:, 1] / ratio_h), 0, src_h)
            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores
...@@ -15,38 +15,21 @@ import numpy as np ...@@ -15,38 +15,21 @@ import numpy as np
import string import string
import paddle import paddle
from paddle.nn import functional as F from paddle.nn import functional as F
import re
class BaseRecLabelDecode(object): class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self, character_dict_path=None, use_space_char=False):
character_dict_path=None,
character_type='ch',
use_space_char=False):
support_character_type = [
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
if character_type == "en": self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type == "EN_symbol": else:
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -56,9 +39,6 @@ class BaseRecLabelDecode(object): ...@@ -56,9 +39,6 @@ class BaseRecLabelDecode(object):
self.character_str.append(" ") self.character_str.append(" ")
dict_character = list(self.character_str) dict_character = list(self.character_str)
else:
raise NotImplementedError
self.character_type = character_type
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
for i, char in enumerate(dict_character): for i, char in enumerate(dict_character):
...@@ -101,15 +81,14 @@ class BaseRecLabelDecode(object): ...@@ -101,15 +81,14 @@ class BaseRecLabelDecode(object):
class CTCLabelDecode(BaseRecLabelDecode): class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self, character_dict_path=None, use_space_char=False,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs): **kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path, super(CTCLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
...@@ -133,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -133,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def __init__(self, def __init__(self,
character_dict_path=None, character_dict_path=None,
character_type='ch',
use_space_char=False, use_space_char=False,
model_name=["student"], model_name=["student"],
key=None, key=None,
**kwargs): **kwargs):
super(DistillationCTCLabelDecode, self).__init__( super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
character_dict_path, character_type, use_space_char) use_space_char)
if not isinstance(model_name, list): if not isinstance(model_name, list):
model_name = [model_name] model_name = [model_name]
self.model_name = model_name self.model_name = model_name
...@@ -156,16 +134,77 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -156,16 +134,77 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return output return output
class NRTRLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions (and optionally labels) into text.

        Args:
            preds: either an ``(ids, probs)`` pair, or a single array/tensor of
                per-step class scores that is reduced with argmax/max over
                axis 2.
            label: optional label indices; the first column is dropped before
                decoding.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        # Only a real pair is treated as (ids, probs). Checking len() alone
        # also matched a score tensor whose batch size happens to be 2.
        if isinstance(preds, (tuple, list)) and len(preds) == 2:
            preds_id = preds[0]
            preds_prob = preds[1]
            if isinstance(preds_id, paddle.Tensor):
                preds_id = preds_id.numpy()
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
            # Index 2 is '<s>' (see add_special_char): strip a leading start
            # token from both ids and probabilities.
            if preds_id[0][0] == 2:
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
        else:
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the blank / unknown / start / end tokens to the dictionary."""
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                # Index 3 is '</s>' (see add_special_char): stop at end token.
                if text_index[batch_idx][idx] == 3:  # end
                    break
                try:
                    char_list.append(self.character[int(text_index[batch_idx][
                        idx])])
                except (IndexError, TypeError, ValueError):
                    # Best effort: skip indices that cannot be mapped to a
                    # character instead of aborting the whole batch.
                    continue
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            # An empty decode would make np.mean return NaN with a warning;
            # report a confidence of 0 instead.
            if not conf_list:
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text.lower(), np.mean(conf_list)))
        return result_list
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self, character_dict_path=None, use_space_char=False,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs): **kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path, super(AttnLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) use_space_char)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
self.beg_str = "sos" self.beg_str = "sos"
...@@ -239,16 +278,91 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -239,16 +278,91 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx return idx
class SEEDLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        """Append the end-of-sequence token to the dictionary."""
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = dict_character + [self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        """Return a one-element list holding the end-token index."""
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the dictionary index of the "sos" or "eos" token."""
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                # Everything from the end token onwards is dropped.
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            # An empty decode would make np.mean return NaN with a warning;
            # report a confidence of 0 instead.
            if not conf_list:
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions (and optionally labels) into text.

        Args:
            preds: dict with a "rec_pred" entry. When "rec_pred_scores" is
                also present, "rec_pred" is used directly as indices and
                "rec_pred_scores" as their probabilities; otherwise "rec_pred"
                is reduced with argmax/max over axis 2.
            label: optional label indices.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        rec_pred = preds["rec_pred"]
        # Convert once and keep using the converted array. Previously the
        # converted value was overwritten by re-reading preds["rec_pred"],
        # so tensors reached decode() unconverted.
        if isinstance(rec_pred, paddle.Tensor):
            rec_pred = rec_pred.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = rec_pred
            preds_prob = preds["rec_pred_scores"]
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
        else:
            preds_idx = rec_pred.argmax(axis=2)
            preds_prob = rec_pred.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
class SRNLabelDecode(BaseRecLabelDecode): class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self, character_dict_path=None, use_space_char=False,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs): **kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path, super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25) self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
...@@ -324,10 +438,9 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -324,10 +438,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object): class TableLabelDecode(object):
""" """ """ """
def __init__(self, def __init__(self, character_dict_path, **kwargs):
character_dict_path, list_character, list_elem = self.load_char_elem_dict(
**kwargs): character_dict_path)
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character) list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) list_elem = self.add_special_char(list_elem)
self.dict_character = {} self.dict_character = {}
...@@ -346,7 +459,8 @@ class TableLabelDecode(object): ...@@ -346,7 +459,8 @@ class TableLabelDecode(object):
list_elem = [] list_elem = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t") substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0]) character_num = int(substr[0])
elem_num = int(substr[1]) elem_num = int(substr[1])
for cno in range(1, 1 + character_num): for cno in range(1, 1 + character_num):
...@@ -366,14 +480,14 @@ class TableLabelDecode(object): ...@@ -366,14 +480,14 @@ class TableLabelDecode(object):
def __call__(self, preds): def __call__(self, preds):
structure_probs = preds['structure_probs'] structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds'] loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor): if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy() structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor): if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy() loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2) structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2) structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx, structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_probs, 'elem') structure_idx, structure_probs, 'elem')
res_html_code_list = [] res_html_code_list = []
res_loc_list = [] res_loc_list = []
batch_num = len(structure_str) batch_num = len(structure_str)
...@@ -388,8 +502,13 @@ class TableLabelDecode(object): ...@@ -388,8 +502,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc) res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code) res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc) res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list, return {
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str} 'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem): def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index. """convert text-label into text-index.
...@@ -454,3 +573,79 @@ class TableLabelDecode(object): ...@@ -454,3 +573,79 @@ class TableLabelDecode(object):
assert False, "Unsupport type %s in char_or_elem" \ assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem % char_or_elem
return idx return idx
class SARLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        # When set, decoded text is lower-cased and filtered by a regex.
        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        """Append the unknown, begin/end and padding tokens and record their indices."""
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        # Start and end share a single token, hence the same index.
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        # Compile the filter once per call rather than once per batch item.
        comp = None
        if self.rm_symbol:
            comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    # Without probabilities, a begin/end token in the first
                    # position is skipped; anywhere else it ends the sequence.
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            # An empty decode would make np.mean return NaN with a warning;
            # report a confidence of 0 instead.
            if not conf_list:
                conf_list = [0]
            text = ''.join(char_list)
            if comp is not None:
                text = text.lower()
                text = comp.sub('', text)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode per-step class scores (and optionally labels) into text.

        ``preds`` is reduced with argmax/max over axis 2 to obtain indices and
        their probabilities. Returns the decoded predictions, or
        ``(predictions, labels)`` when ``label`` is given.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return the indices that decode() skips (only the padding token)."""
        return [self.padding_idx]
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
\ No newline at end of file
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
_
`
~
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
EPS = 1e-6
def iou_single(a, b, mask, n_class):
    """Mean intersection-over-union of two label tensors over ``n_class`` classes.

    Only positions where ``mask == 1`` take part. For every class id in
    ``range(n_class)`` the ratio ``sum(intersection) / (sum(union) + EPS)`` is
    computed, and the ratios are averaged.
    """
    keep = mask == 1
    sel_a = a.masked_select(keep)
    sel_b = b.masked_select(keep)

    ratios = []
    for cls in range(n_class):
        if sel_a.shape == [0] and sel_a.shape == sel_b.shape:
            # Nothing selected by the mask: both terms are zero.
            overlap = paddle.to_tensor(0.0)
            combined = paddle.to_tensor(0.0)
        else:
            hit_a = sel_a == cls
            hit_b = sel_b == cls
            overlap = hit_a.logical_and(hit_b).astype('float32')
            combined = hit_a.logical_or(hit_b).astype('float32')

        ratios.append(paddle.sum(overlap) / (paddle.sum(combined) + EPS))

    return sum(ratios) / len(ratios)
def iou(a, b, mask, n_class=2, reduce=True):
    """Per-sample mean IoU for a batch, optionally averaged over the batch.

    ``a``, ``b`` and ``mask`` are flattened to ``[batch, -1]`` and each row is
    scored with ``iou_single``. Returns the batch mean when ``reduce`` is
    true, otherwise the tensor of per-sample values.
    """
    n_samples = a.shape[0]

    flat_a = a.reshape([n_samples, -1])
    flat_b = b.reshape([n_samples, -1])
    flat_mask = mask.reshape([n_samples, -1])

    per_sample = paddle.zeros((n_samples,), dtype='float32')
    for row in range(n_samples):
        per_sample[row] = iou_single(flat_a[row], flat_b[row], flat_mask[row],
                                     n_class)

    if not reduce:
        return per_sample
    return paddle.mean(per_sample)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment