Unverified Commit 465ef3bf authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into bm_dyg

parents bf9f93f7 bc999986
...@@ -23,10 +23,10 @@ import paddle.nn.functional as F ...@@ -23,10 +23,10 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
def get_bias_attr(k, name): def get_bias_attr(k):
stdv = 1.0 / math.sqrt(k * 1.0) stdv = 1.0 / math.sqrt(k * 1.0)
initializer = paddle.nn.initializer.Uniform(-stdv, stdv) initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr") bias_attr = ParamAttr(initializer=initializer)
return bias_attr return bias_attr
...@@ -38,18 +38,14 @@ class Head(nn.Layer): ...@@ -38,18 +38,14 @@ class Head(nn.Layer):
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr(name=name_list[0] + '.w_0'), weight_attr=ParamAttr(),
bias_attr=False) bias_attr=False)
self.conv_bn1 = nn.BatchNorm( self.conv_bn1 = nn.BatchNorm(
num_channels=in_channels // 4, num_channels=in_channels // 4,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name_list[1] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)), initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr( bias_attr=ParamAttr(
name=name_list[1] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)), initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[1] + '.w_1',
moving_variance_name=name_list[1] + '.w_2',
act='relu') act='relu')
self.conv2 = nn.Conv2DTranspose( self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
...@@ -57,19 +53,14 @@ class Head(nn.Layer): ...@@ -57,19 +53,14 @@ class Head(nn.Layer):
kernel_size=2, kernel_size=2,
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
name=name_list[2] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) bias_attr=get_bias_attr(in_channels // 4))
self.conv_bn2 = nn.BatchNorm( self.conv_bn2 = nn.BatchNorm(
num_channels=in_channels // 4, num_channels=in_channels // 4,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name_list[3] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)), initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr( bias_attr=ParamAttr(
name=name_list[3] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)), initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[3] + '.w_1',
moving_variance_name=name_list[3] + '.w_2',
act="relu") act="relu")
self.conv3 = nn.Conv2DTranspose( self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
...@@ -77,10 +68,8 @@ class Head(nn.Layer): ...@@ -77,10 +68,8 @@ class Head(nn.Layer):
kernel_size=2, kernel_size=2,
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
name=name_list[4] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), bias_attr=get_bias_attr(in_channels // 4), )
)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
...@@ -117,7 +106,7 @@ class DBHead(nn.Layer): ...@@ -117,7 +106,7 @@ class DBHead(nn.Layer):
def step_function(self, x, y): def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
def forward(self, x): def forward(self, x, targets=None):
shrink_maps = self.binarize(x) shrink_maps = self.binarize(x)
if not self.training: if not self.training:
return {'maps': shrink_maps} return {'maps': shrink_maps}
......
...@@ -109,7 +109,7 @@ class EASTHead(nn.Layer): ...@@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
act=None, act=None,
name="f_geo") name="f_geo")
def forward(self, x): def forward(self, x, targets=None):
f_det = self.det_conv1(x) f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det) f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det) f_score = self.score_conv(f_det)
......
...@@ -116,7 +116,7 @@ class SASTHead(nn.Layer): ...@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels) self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels) self.head2 = SAST_Header2(in_channels)
def forward(self, x): def forward(self, x, targets=None):
f_score, f_border = self.head1(x) f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x) f_tvo, f_tco = self.head2(x)
......
...@@ -220,7 +220,7 @@ class PGHead(nn.Layer): ...@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False) bias_attr=False)
def forward(self, x): def forward(self, x, targets=None):
f_score = self.conv_f_score1(x) f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score) f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score) f_score = self.conv_f_score3(f_score)
......
...@@ -23,32 +23,57 @@ from paddle import ParamAttr, nn ...@@ -23,32 +23,57 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F from paddle.nn import functional as F
def get_para_bias_attr(l2_decay, k, name): def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay) regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0) stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv) initializer = nn.initializer.Uniform(-stdv, stdv)
weight_attr = ParamAttr( weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
regularizer=regularizer, initializer=initializer, name=name + "_w_attr") bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [weight_attr, bias_attr] return [weight_attr, bias_attr]
class CTCHead(nn.Layer): class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): def __init__(self,
in_channels,
out_channels,
fc_decay=0.0004,
mid_channels=None,
**kwargs):
super(CTCHead, self).__init__() super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr( if mid_channels is None:
l2_decay=fc_decay, k=in_channels, name='ctc_fc') weight_attr, bias_attr = get_para_bias_attr(
self.fc = nn.Linear( l2_decay=fc_decay, k=in_channels)
in_channels, self.fc = nn.Linear(
out_channels, in_channels,
weight_attr=weight_attr, out_channels,
bias_attr=bias_attr, weight_attr=weight_attr,
name='ctc_fc') bias_attr=bias_attr)
else:
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc1 = nn.Linear(
in_channels,
mid_channels,
weight_attr=weight_attr1,
bias_attr=bias_attr1)
weight_attr2, bias_attr2 = get_para_bias_attr(
l2_decay=fc_decay, k=mid_channels)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
weight_attr=weight_attr2,
bias_attr=bias_attr2)
self.out_channels = out_channels self.out_channels = out_channels
self.mid_channels = mid_channels
def forward(self, x, labels=None): def forward(self, x, targets=None):
predicts = self.fc(x) if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)
if not self.training: if not self.training:
predicts = F.softmax(predicts, axis=2) predicts = F.softmax(predicts, axis=2)
return predicts return predicts
...@@ -250,7 +250,8 @@ class SRNHead(nn.Layer): ...@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
def forward(self, inputs, others): def forward(self, inputs, targets=None):
others = targets[-4:]
encoder_word_pos = others[0] encoder_word_pos = others[0]
gsrm_word_pos = others[1] gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2] gsrm_slf_attn_bias1 = others[2]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Layer):
    """Attention decoder head producing table-structure scores and box values.

    An ``AttentionGRUCell`` is stepped ``max_elem_length + 1`` times over the
    last input feature map. The stacked per-step outputs are projected to
    ``elem_num`` structure scores and to four sigmoid-activated location
    values per step.
    """

    def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
        """Create the attention cell and the structure/location projections.

        Args:
            in_channels: Indexable collection of channel counts; only the
                last entry is used as the attention input size.
            hidden_size: Hidden size of the attention cell and input size of
                the structure projection.
            loc_type: ``1`` predicts locations from the decoder output alone;
                any other value additionally concatenates a linear transform
                of the visual features before the location projection.
            in_max_len: Selects the input width of ``loc_fea_trans``
                (640 -> 400, 800 -> 625, anything else -> 256). Presumably
                the flattened spatial size of the feature map at that input
                resolution -- TODO confirm against the backbone.
            **kwargs: Accepted and ignored.
        """
        super(TableAttentionHead, self).__init__()
        self.input_size = in_channels[-1]
        self.hidden_size = hidden_size
        # One-hot width of the decoder input and width of the structure logits.
        self.elem_num = 30
        # NOTE(review): max_text_length and max_cell_num are stored but not
        # read anywhere in this class.
        self.max_text_length = 100
        # The decoder always runs max_elem_length + 1 steps.
        self.max_elem_length = 500
        self.max_cell_num = 500
        self.structure_attention_cell = AttentionGRUCell(
            self.input_size, hidden_size, self.elem_num, use_gru=False)
        self.structure_generator = nn.Linear(hidden_size, self.elem_num)
        self.loc_type = loc_type
        self.in_max_len = in_max_len
        if self.loc_type == 1:
            self.loc_generator = nn.Linear(hidden_size, 4)
        else:
            # Maps the last axis of the transposed feature map to one column
            # per decoding step so it can be concatenated with the decoder
            # output in forward().
            if self.in_max_len == 640:
                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
            elif self.in_max_len == 800:
                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
            else:
                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
            self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)

    def _char_to_onehot(self, input_char, onehot_dim):
        """Return ``F.one_hot(input_char, onehot_dim)``."""
        input_ont_hot = F.one_hot(input_char, onehot_dim)
        return input_ont_hot

    def forward(self, inputs, targets=None):
        """Decode structure scores and location values from ``inputs[-1]``.

        Args:
            inputs: Indexable collection of feature maps; only the last one
                is used.
            targets: When the layer is in training mode and ``targets`` is
                not None, ``targets[0]`` is read column by column
                (``structure[:, i]``) as the decoder input for each step.
                Otherwise the argmax of the previous step is fed back.

        Returns:
            dict with ``'structure_probs'`` and ``'loc_preds'``. In the
            training branch ``structure_probs`` is the raw output of
            ``structure_generator``; in the other branch ``F.softmax`` is
            applied to it. ``loc_preds`` is sigmoid-activated in both.
        """
        # if and else branch are both needed when you want to assign a variable
        # if you modify the var in just one branch, then the modification will not work.
        fea = inputs[-1]
        if len(fea.shape) == 3:
            pass
        else:
            # Flatten every axis after the second into one axis.
            last_shape = int(np.prod(fea.shape[2:]))  # gry added
            fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
            fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
        batch_size = fea.shape[0]
        hidden = paddle.zeros((batch_size, self.hidden_size))
        output_hiddens = []
        if self.training and targets is not None:
            # Teacher forcing: the ground-truth element at step i drives the cell.
            structure = targets[0]
            for i in range(self.max_elem_length+1):
                elem_onehots = self._char_to_onehot(
                    structure[:, i], onehot_dim=self.elem_num)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            structure_probs = self.structure_generator(output)
            if self.loc_type == 1:
                loc_preds = self.loc_generator(output)
                loc_preds = F.sigmoid(loc_preds)
            else:
                loc_fea = fea.transpose([0, 2, 1])
                loc_fea = self.loc_fea_trans(loc_fea)
                loc_fea = loc_fea.transpose([0, 2, 1])
                loc_concat = paddle.concat([output, loc_fea], axis=2)
                loc_preds = self.loc_generator(loc_concat)
                loc_preds = F.sigmoid(loc_preds)
        else:
            # Greedy decoding, starting from element index 0.
            temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
            # Pre-assign every name written inside the loop (see the note at
            # the top of this method).
            structure_probs = None
            loc_preds = None
            elem_onehots = None
            outputs = None
            alpha = None
            # The loop bound is a tensor and the loop is a `while`; presumably
            # this is required for dynamic-to-static export -- TODO confirm
            # before simplifying.
            max_elem_length = paddle.to_tensor(self.max_elem_length)
            i = 0
            while i < max_elem_length+1:
                elem_onehots = self._char_to_onehot(
                    temp_elem, onehot_dim=self.elem_num)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                structure_probs_step = self.structure_generator(outputs)
                # The most likely element becomes the next step's input.
                temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
                i += 1
            output = paddle.concat(output_hiddens, axis=1)
            structure_probs = self.structure_generator(output)
            structure_probs = F.softmax(structure_probs)
            if self.loc_type == 1:
                loc_preds = self.loc_generator(output)
                loc_preds = F.sigmoid(loc_preds)
            else:
                loc_fea = fea.transpose([0, 2, 1])
                loc_fea = self.loc_fea_trans(loc_fea)
                loc_fea = loc_fea.transpose([0, 2, 1])
                loc_concat = paddle.concat([output, loc_fea], axis=2)
                loc_preds = self.loc_generator(loc_concat)
                loc_preds = F.sigmoid(loc_preds)
        return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer):
    """One additive-attention step followed by a GRU cell update."""

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        """Create the attention projections and the recurrent cell.

        Args:
            input_size: Size of the last axis of the attended features.
            hidden_size: Hidden size of the projections and of the GRU cell.
            num_embeddings: Width of the one-hot vector appended to the
                attention context before it enters the GRU cell.
            use_gru: Accepted but not consulted; a GRU cell is always built.
        """
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        rnn_input_size = input_size + num_embeddings
        self.rnn = nn.GRUCell(input_size=rnn_input_size, hidden_size=hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Attend over ``batch_H`` and advance the GRU cell by one step.

        Args:
            prev_hidden: Previous hidden state of the cell.
            batch_H: Features to attend over (assumed (batch, steps,
                input_size) -- TODO confirm against callers).
            char_onehots: One-hot encoding of the current input symbol.

        Returns:
            ``(cur_hidden, alpha)``: the GRU cell's return value and the
            attention weights (softmax over axis 1, then axes 1 and 2 swapped).
        """
        features_proj = self.i2h(batch_H)
        state_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
        energy = paddle.tanh(paddle.add(features_proj, state_proj))
        weights = F.softmax(self.score(energy), axis=1)
        alpha = paddle.transpose(weights, [0, 2, 1])
        # Weighted sum of the features, with the singleton axis 1 removed.
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        rnn_input = paddle.concat([context, char_onehots], 1)
        return self.rnn(rnn_input, prev_hidden), alpha
class AttentionLSTM(nn.Layer):
    """Attention decoder built on ``AttentionLSTMCell`` with a linear classifier."""

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        """Create the attention cell and the output projection.

        Args:
            in_channels: Size of the last axis of the attended features.
            out_channels: Number of output classes; also the one-hot width
                of the decoder input.
            hidden_size: Hidden size of the LSTM cell.
            **kwargs: Accepted and ignored.
        """
        super(AttentionLSTM, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels
        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        """Return ``F.one_hot(input_char, onehot_dim)``."""
        input_ont_hot = F.one_hot(input_char, onehot_dim)
        return input_ont_hot

    def forward(self, inputs, targets=None, batch_max_length=25):
        """Run ``batch_max_length`` decoding steps over ``inputs``.

        Args:
            inputs: Features to attend over; ``inputs.shape[0]`` is used as
                the batch size.
            targets: If given, ``targets[:, i]`` is fed to the cell at step
                ``i`` (teacher forcing). If None, decoding starts from index
                0 and feeds back the argmax of each step.
            batch_max_length: Number of decoding steps.

        Returns:
            Per-step outputs of ``self.generator`` stacked along axis 1
            (no softmax is applied here). In the greedy branch the result is
            None when ``batch_max_length`` is 0.
        """
        batch_size = inputs.shape[0]
        num_steps = batch_max_length
        hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
            (batch_size, self.hidden_size)))
        if targets is not None:
            output_hiddens = []
            for i in range(num_steps):
                # One-hot vector of the i-th ground-truth character.
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, inputs, char_onehots)
                # The cell returns (output, (h, c)); keep only the (h, c) state.
                hidden = (hidden[1][0], hidden[1][1])
                output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            prev_char = paddle.zeros(shape=[batch_size], dtype="int32")
            # Collect per-step tensors and concatenate once at the end rather
            # than re-concatenating the growing result on every step.
            step_probs = []
            for _ in range(num_steps):
                char_onehots = self._char_to_onehot(
                    prev_char, onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, inputs, char_onehots)
                probs_step = self.generator(hidden[0])
                hidden = (hidden[1][0], hidden[1][1])
                step_probs.append(paddle.unsqueeze(probs_step, axis=1))
                # Greedy choice becomes the next step's input.
                prev_char = probs_step.argmax(axis=1)
            probs = paddle.concat(step_probs, axis=1) if step_probs else None
        return probs
class AttentionLSTMCell(nn.Layer):
    """One additive-attention step followed by an LSTM (or GRU) cell update."""

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        """Create the attention projections and the recurrent cell.

        Args:
            input_size: Size of the last axis of the attended features.
            hidden_size: Hidden size of the projections and of the cell.
            num_embeddings: Width of the one-hot vector appended to the
                attention context before it enters the cell.
            use_gru: Build ``nn.GRUCell`` when true, ``nn.LSTMCell`` otherwise.
        """
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        cell_cls = nn.GRUCell if use_gru else nn.LSTMCell
        self.rnn = cell_cls(
            input_size=input_size + num_embeddings, hidden_size=hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Attend over ``batch_H`` and advance the recurrent cell by one step.

        Args:
            prev_hidden: Previous cell state; element 0 drives the attention
                query and the whole object is handed to the cell.
            batch_H: Features to attend over (assumed (batch, steps,
                input_size) -- TODO confirm against callers).
            char_onehots: One-hot encoding of the current input symbol.

        Returns:
            ``(cur_hidden, alpha)``: the cell's return value and the
            attention weights (softmax over axis 1, then axes 1 and 2 swapped).
        """
        features_proj = self.i2h(batch_H)
        query_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
        energy = paddle.tanh(paddle.add(features_proj, query_proj))
        weights = F.softmax(self.score(energy), axis=1)
        alpha = paddle.transpose(weights, [0, 2, 1])
        # Weighted sum of the features, with the singleton axis 1 removed.
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        rnn_input = paddle.concat([context, char_onehots], 1)
        return self.rnn(rnn_input, prev_hidden), alpha
...@@ -21,7 +21,8 @@ def build_neck(config): ...@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
from .pg_fpn import PGFPN from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
...@@ -32,61 +32,53 @@ class DBFPN(nn.Layer): ...@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
in_channels=in_channels[0], in_channels=in_channels[0],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_51.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in3_conv = nn.Conv2D( self.in3_conv = nn.Conv2D(
in_channels=in_channels[1], in_channels=in_channels[1],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_50.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in4_conv = nn.Conv2D( self.in4_conv = nn.Conv2D(
in_channels=in_channels[2], in_channels=in_channels[2],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_49.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in5_conv = nn.Conv2D( self.in5_conv = nn.Conv2D(
in_channels=in_channels[3], in_channels=in_channels[3],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_48.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p5_conv = nn.Conv2D( self.p5_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_52.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p4_conv = nn.Conv2D( self.p4_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_53.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p3_conv = nn.Conv2D( self.p3_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_54.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p2_conv = nn.Conv2D( self.p2_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_55.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
def forward(self, x): def forward(self, x):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class TableFPN(nn.Layer):
    """Feature-pyramid neck that adds a fused multi-level residual to ``c5``."""

    def __init__(self, in_channels, out_channels, **kwargs):
        """Build the lateral, smoothing and fusion convolutions.

        Args:
            in_channels: Channel counts of the four input feature maps, in
                the order they are unpacked in ``forward`` (c2, c3, c4, c5).
            out_channels: Accepted but not consulted; the internal width is
                fixed at 512.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.out_channels = 512
        kaiming = paddle.nn.initializer.KaimingUniform()

        def bias_free_conv(c_in, c_out, kernel):
            # stride=1 and padding=kernel // 2 (0 for 1x1, 1 for 3x3)
            # reproduce the explicit or default arguments of each convolution.
            return nn.Conv2D(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel,
                stride=1,
                padding=kernel // 2,
                weight_attr=ParamAttr(initializer=kaiming),
                bias_attr=False)

        width = self.out_channels
        quarter = self.out_channels // 4

        # Sub-layers are registered in this exact order: laterals, then the
        # 3x3 convolutions, then the fusion convolution.
        self.in2_conv = bias_free_conv(in_channels[0], width, 1)
        self.in3_conv = bias_free_conv(in_channels[1], width, 1)
        self.in4_conv = bias_free_conv(in_channels[2], width, 1)
        self.in5_conv = bias_free_conv(in_channels[3], width, 1)
        # NOTE(review): the p*_conv layers are created but never called in
        # forward().
        self.p5_conv = bias_free_conv(width, quarter, 3)
        self.p4_conv = bias_free_conv(width, quarter, 3)
        self.p3_conv = bias_free_conv(width, quarter, 3)
        self.p2_conv = bias_free_conv(width, quarter, 3)
        self.fuse_conv = bias_free_conv(width * 4, 512, 3)

    def forward(self, x):
        """Fuse four feature levels and return ``[c5 + 0.005 * fused]``.

        Args:
            x: Four feature maps, unpacked as ``c2, c3, c4, c5``.

        Returns:
            A one-element list holding ``c5`` plus the scaled fusion output.
        """

        def resize_like(tensor, reference):
            # Nearest-neighbour resize to the reference's axes 2 and 3.
            return F.upsample(
                tensor, size=reference.shape[2:4], mode="nearest", align_mode=1)

        c2, c3, c4, c5 = x
        in5 = self.in5_conv(c5)
        in4 = self.in4_conv(c4)
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)

        # Top-down pathway.
        out4 = in4 + resize_like(in5, in4)  # 1/16
        out3 = in3 + resize_like(out4, in3)  # 1/8
        out2 = in2 + resize_like(out3, in2)  # 1/4

        # Bring every level to the size of in5 before concatenation.
        p4 = resize_like(out4, in5)
        p3 = resize_like(out3, in5)
        p2 = resize_like(out2, in5)

        fuse = paddle.concat([in5, p4, p3, p2], axis=1)
        residual = self.fuse_conv(fuse) * 0.005
        return [c5 + residual]
...@@ -21,18 +21,20 @@ import copy ...@@ -21,18 +21,20 @@ import copy
__all__ = ['build_post_process'] __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
    """Convert between text-label and text-index for distillation outputs.

    The prediction passed to ``__call__`` is a mapping keyed by sub-model
    name; each selected entry is decoded with ``CTCLabelDecode``.
    """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 model_name=None,
                 key=None,
                 **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            character_dict_path: Forwarded to ``CTCLabelDecode``.
            character_type: Forwarded to ``CTCLabelDecode``.
            use_space_char: Forwarded to ``CTCLabelDecode``.
            model_name: Name or list of names of the sub-models to decode.
                Defaults to ``["student"]``.
            key: Optional key used to pick the tensor out of each
                sub-model's output before decoding.
            **kwargs: Accepted and ignored.
        """
        super(DistillationCTCLabelDecode, self).__init__(
            character_dict_path, character_type, use_space_char)
        # A None sentinel avoids sharing one mutable default list between
        # all instances.
        if model_name is None:
            model_name = ["student"]
        elif not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode the prediction of every configured sub-model.

        Args:
            preds: Mapping from sub-model name to its output.
            label: Optional label forwarded to ``CTCLabelDecode.__call__``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            dict mapping each name in ``self.model_name`` to the parent
            decoder's result for that sub-model.
        """
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            # ``label`` is passed positionally: combining ``label=label``
            # with ``*args`` raises TypeError as soon as ``args`` is
            # non-empty, because the first extra positional also binds to
            # ``label``. Assumes the parent signature mirrors this one
            # (preds, label=None, *args, **kwargs) -- TODO confirm.
            output[name] = super().__call__(pred, label, *args, **kwargs)
        return output
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
...@@ -294,14 +325,8 @@ class TableLabelDecode(object): ...@@ -294,14 +325,8 @@ class TableLabelDecode(object):
""" """ """ """
def __init__(self, def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path, character_dict_path,
**kwargs): **kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path) list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character) list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) list_elem = self.add_special_char(list_elem)
...@@ -338,18 +363,6 @@ class TableLabelDecode(object): ...@@ -338,18 +363,6 @@ class TableLabelDecode(object):
list_character = [self.beg_str] + list_character + [self.end_str] list_character = [self.beg_str] + list_character + [self.end_str]
return list_character return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds): def __call__(self, preds):
structure_probs = preds['structure_probs'] structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds'] loc_preds = preds['loc_preds']
......
...@@ -22,7 +22,7 @@ logger_initialized = {} ...@@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache() @functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO): def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name. """Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will logger by adding one or two handlers, otherwise the initialized logger will
......
...@@ -23,6 +23,8 @@ import six ...@@ -23,6 +23,8 @@ import six
import paddle import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
...@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger): ...@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): def init_model(config, model, optimizer=None, lr_scheduler=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
if load_static_weights:
pre_state_dict = paddle.static.load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB
if weight_name in pre_state_dict.keys():
# logger.info('Load weight: {}, shape: {}'.format(
# weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key:
# delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[
weight_name].squeeze()
# change axis
if len(pre_state_dict[weight_name].shape) > 1:
pre_state_dict[weight_name] = pre_state_dict[
weight_name].transpose((1, 0))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_state_dict(param_state_dict)
return
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
logger = get_logger()
global_config = config['Global'] global_config = config['Global']
checkpoints = global_config.get('checkpoints') checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
...@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {}) best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict: if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
load_static_weights = global_config.get('load_static_weights', False)
if not isinstance(pretrained_model, list): if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model] pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list): for pretrained in pretrained_model:
load_static_weights = [load_static_weights] * len(pretrained_model) if not (os.path.isdir(pretrained) or
for idx, pretrained in enumerate(pretrained_model): os.path.exists(pretrained + '.pdparams')):
load_static = load_static_weights[idx] raise ValueError("Model pretrain path {} does not "
load_dygraph_pretrain( "exists.".format(pretrained))
model, logger, path=pretrained, load_static_weights=load_static) param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format( logger.info("load pretrained model from {}".format(
pretrained_model)) pretrained_model))
else: else:
......
# TableStructurer
1. 代码使用
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/to/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
2. 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
# 表格结构和内容预测
先cd到PaddleOCR/ppstructure目录下
预测
```bash
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
评估
```bash
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment