Unverified Commit 465ef3bf authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into bm_dyg

parents bf9f93f7 bc999986
......@@ -23,10 +23,10 @@ import paddle.nn.functional as F
from paddle import ParamAttr
def get_bias_attr(k, name):
def get_bias_attr(k):
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr")
bias_attr = ParamAttr(initializer=initializer)
return bias_attr
......@@ -38,18 +38,14 @@ class Head(nn.Layer):
out_channels=in_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(name=name_list[0] + '.w_0'),
weight_attr=ParamAttr(),
bias_attr=False)
self.conv_bn1 = nn.BatchNorm(
num_channels=in_channels // 4,
param_attr=ParamAttr(
name=name_list[1] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr(
name=name_list[1] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[1] + '.w_1',
moving_variance_name=name_list[1] + '.w_2',
act='relu')
self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
......@@ -57,19 +53,14 @@ class Head(nn.Layer):
kernel_size=2,
stride=2,
weight_attr=ParamAttr(
name=name_list[2] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
bias_attr=get_bias_attr(in_channels // 4))
self.conv_bn2 = nn.BatchNorm(
num_channels=in_channels // 4,
param_attr=ParamAttr(
name=name_list[3] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr(
name=name_list[3] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[3] + '.w_1',
moving_variance_name=name_list[3] + '.w_2',
act="relu")
self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
......@@ -77,10 +68,8 @@ class Head(nn.Layer):
kernel_size=2,
stride=2,
weight_attr=ParamAttr(
name=name_list[4] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
)
bias_attr=get_bias_attr(in_channels // 4), )
def forward(self, x):
x = self.conv1(x)
......@@ -117,7 +106,7 @@ class DBHead(nn.Layer):
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
def forward(self, x):
def forward(self, x, targets=None):
shrink_maps = self.binarize(x)
if not self.training:
return {'maps': shrink_maps}
......
......@@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
act=None,
name="f_geo")
def forward(self, x):
def forward(self, x, targets=None):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
......
......@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
def forward(self, x):
def forward(self, x, targets=None):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
......
......@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
def forward(self, x):
def forward(self, x, targets=None):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
......
......@@ -23,32 +23,57 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F
def get_para_bias_attr(l2_decay, k, name):
def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv)
weight_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
bias_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
return [weight_attr, bias_attr]
class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
def __init__(self,
in_channels,
out_channels,
fc_decay=0.0004,
mid_channels=None,
**kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr,
name='ctc_fc')
if mid_channels is None:
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr)
else:
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc1 = nn.Linear(
in_channels,
mid_channels,
weight_attr=weight_attr1,
bias_attr=bias_attr1)
weight_attr2, bias_attr2 = get_para_bias_attr(
l2_decay=fc_decay, k=mid_channels)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
weight_attr=weight_attr2,
bias_attr=bias_attr2)
self.out_channels = out_channels
self.mid_channels = mid_channels
def forward(self, x, labels=None):
predicts = self.fc(x)
def forward(self, x, targets=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)
if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
......@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
def forward(self, inputs, others):
def forward(self, inputs, targets=None):
others = targets[-4:]
encoder_word_pos = others[0]
gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2]
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len
if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_elem_length+1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
else:
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
structure_probs = None
loc_preds = None
elem_onehots = None
outputs = None
alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0
while i < max_elem_length+1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
......@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
......@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_51.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_50.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_49.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_48.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_52.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_53.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_54.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_55.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
def forward(self, x):
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class TableFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs):
super(TableFPN, self).__init__()
self.out_channels = 512
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
stride = 1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4,
out_channels=512,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.in5_conv(c5)
in4 = self.in4_conv(c4)
in3 = self.in3_conv(c3)
in2 = self.in2_conv(c2)
out4 = in4 + F.upsample(
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
out3 = in3 + F.upsample(
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
out2 = in2 + F.upsample(
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
fuse_conv = self.fuse_conv(fuse) * 0.005
return [c5 + fuse_conv]
......@@ -21,18 +21,20 @@ import copy
__all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode'
]
config = copy.deepcopy(config)
......
......@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......@@ -294,14 +325,8 @@ class TableLabelDecode(object):
""" """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
......@@ -338,18 +363,6 @@ class TableLabelDecode(object):
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
......
......@@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
......
......@@ -23,6 +23,8 @@ import six
import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
......@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
if load_static_weights:
pre_state_dict = paddle.static.load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB
if weight_name in pre_state_dict.keys():
# logger.info('Load weight: {}, shape: {}'.format(
# weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key:
# delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[
weight_name].squeeze()
# change axis
if len(pre_state_dict[weight_name].shape) > 1:
pre_state_dict[weight_name] = pre_state_dict[
weight_name].transpose((1, 0))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_state_dict(param_state_dict)
return
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
def init_model(config, model, optimizer=None, lr_scheduler=None):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
global_config = config['Global']
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
......@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_static_weights = global_config.get('load_static_weights', False)
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model, logger, path=pretrained, load_static_weights=load_static)
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
else:
......
# TableStructurer
1. 代码使用
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
2. 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
# 表格结构和内容预测
先cd到PaddleOCR/ppstructure目录下
预测
```python
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
评估
```python
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment