Commit 253b8453 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 7cad4817 bc999986
......@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
def forward(self, x):
def forward(self, x, targets=None):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
......
......@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
def forward(self, x):
def forward(self, x, targets=None):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
......
......@@ -23,32 +23,57 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F
def get_para_bias_attr(l2_decay, k, name):
def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv)
weight_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
bias_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
return [weight_attr, bias_attr]
class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
def __init__(self,
in_channels,
out_channels,
fc_decay=0.0004,
mid_channels=None,
**kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr,
name='ctc_fc')
if mid_channels is None:
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr)
else:
weight_attr1, bias_attr1 = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels)
self.fc1 = nn.Linear(
in_channels,
mid_channels,
weight_attr=weight_attr1,
bias_attr=bias_attr1)
weight_attr2, bias_attr2 = get_para_bias_attr(
l2_decay=fc_decay, k=mid_channels)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
weight_attr=weight_attr2,
bias_attr=bias_attr2)
self.out_channels = out_channels
self.mid_channels = mid_channels
def forward(self, x, labels=None):
predicts = self.fc(x)
def forward(self, x, targets=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
predicts = self.fc1(x)
predicts = self.fc2(predicts)
if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
......@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
def forward(self, inputs, others):
def forward(self, inputs, targets=None):
others = targets[-4:]
encoder_word_pos = others[0]
gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2]
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len
if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_elem_length+1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
else:
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
structure_probs = None
loc_preds = None
elem_onehots = None
outputs = None
alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0
while i < max_elem_length+1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
if self.loc_type == 1:
loc_preds = self.loc_generator(output)
loc_preds = F.sigmoid(loc_preds)
else:
loc_fea = fea.transpose([0, 2, 1])
loc_fea = self.loc_fea_trans(loc_fea)
loc_fea = loc_fea.transpose([0, 2, 1])
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
......@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
from .table_fpn import TableFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
......@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_51.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_50.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_49.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(
name='conv2d_48.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_52.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_53.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_54.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
name='conv2d_55.w_0', initializer=weight_attr),
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
def forward(self, x):
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class TableFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs):
super(TableFPN, self).__init__()
self.out_channels = 512
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
stride = 1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.fuse_conv = nn.Conv2D(
in_channels=self.out_channels * 4,
out_channels=512,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.in5_conv(c5)
in4 = self.in4_conv(c4)
in3 = self.in3_conv(c3)
in2 = self.in2_conv(c2)
out4 = in4 + F.upsample(
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
out3 = in3 + F.upsample(
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
out2 = in2 + F.upsample(
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
fuse_conv = self.fuse_conv(fuse) * 0.005
return [c5 + fuse_conv]
......@@ -21,18 +21,20 @@ import copy
__all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode'
]
config = copy.deepcopy(config)
......
......@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
self.character_str.append(line)
if use_space_char:
self.character_str += " "
self.character_str.append(" ")
dict_character = list(self.character_str)
else:
......@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......@@ -288,3 +319,138 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self,
character_dict_path,
**kwargs):
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
</overline>
α

$
ω
ψ
χ
(
υ
σ
,
ρ
ε
0
4
8
b
<
Ψ
Ω
D
3
Π
H
</strike>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
̌
«
#
</b>
'
Ι
+
/
·
7
;
?
C
÷
G
K
<sup>
O
S
С
W
Α
[
_
c
z
g
<i>
o
<sub>
s
w
φ
ʹ
{
»
̆
e
ˆ
τ
ι
Ø
ß
×
˃
˂
"
i
&
π
*
æ
.
ø
Q
6
:
>
a
B
F
J
̄
N
R
V
<overline>
Z
^
¤
¥
§
<underline>
¢
£
­
Λ
©
n
r
°
±
v
<b>
k
~
̇
@
ł
®
!
</sup>
%
)
-
1
5
9
=
А
A
Σ
E
I
M
m
̨
</i>
U
Y
]
̸
2
̂
̀
́
̊
̈
q
u
ı
y
</underline>
̃
}
ν
This diff is collapsed.
......@@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tarfile
import requests
from tqdm import tqdm
from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def is_link(s):
return s is not None and s.startswith('http')
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
url = default_url
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
file_name = url.split('/')[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
......@@ -23,6 +23,8 @@ import six
import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
......@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
if load_static_weights:
pre_state_dict = paddle.static.load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB
if weight_name in pre_state_dict.keys():
# logger.info('Load weight: {}, shape: {}'.format(
# weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key:
# delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[
weight_name].squeeze()
# change axis
if len(pre_state_dict[weight_name].shape) > 1:
pre_state_dict[weight_name] = pre_state_dict[
weight_name].transpose((1, 0))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_state_dict(param_state_dict)
return
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
def init_model(config, model, optimizer=None, lr_scheduler=None):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
global_config = config['Global']
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
......@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_static_weights = global_config.get('load_static_weights', False)
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model, logger, path=pretrained, load_static_weights=load_static)
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
else:
......
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .paddlestructure import PaddleStructure, draw_result, to_excel
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
# PaddleStructure
install layoutparser
```sh
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install layoutparser-0.0.0-py3-none-any.whl
```
## 1. Introduction to pipeline
PaddleStructure is a toolkit for complex layout text OCR, the process is as follows
![pipeline](../doc/table/pipeline.png)
In PaddleStructure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, and the OCR process will be carried out according to the category.
Currently layoutparser will output five categories:
1. Text
2. Title
3. Figure
4. List
5. Table
Types 1-4 follow the traditional OCR process, and 5 follow the Table OCR process.
## 2. LayoutParser
## 3. Table OCR
[doc](table/README.md)
## 4. Predictive by inference engine
Use the following commands to complete the inference
```python
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel, and the excel file name will be the coordinates of the table in the image.
## 5. PaddleStructure whl package introduction
### 5.1 Use
5.1.1 Use by code
```python
import os
import cv2
from paddlestructure import PaddleStructure,draw_result,save_res
table_engine = PaddleStructure(show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
5.1.2 Use by command line
```bash
paddlestructure --image_dir=../doc/table/1.png
```
### Parameter Description
Most of the parameters are consistent with the paddleocr whl package, see [whl package documentation](../doc/doc_ch/whl.md)
| Parameter | Description | Default |
|------------------------|------------------------------------------------------|------------------|
| output | The path where excel and recognition results are saved | ./output/table |
| structure_max_len | When the table structure model predicts, the long side of the image is resized | 488 |
| structure_model_dir | Table structure inference model path | None |
| structure_char_type | Dictionary path used by table structure model | ../ppocr/utils/dict/table_structure_dict.tx |
# PaddleStructure
安装layoutparser
```sh
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install layoutparser-0.0.0-py3-none-any.whl
```
## 1. pipeline介绍
PaddleStructure 是一个用于复杂板式文字OCR的工具包,流程如下
![pipeline](../doc/table/pipeline.png)
在PaddleStructure中,图片会先经由layoutparser进行版面分析,在版面分析中,会对图片里的区域进行分类,根据根据类别进行对于的ocr流程。
目前layoutparser会输出五个类别:
1. Text
2. Title
3. Figure
4. List
5. Table
1-4类走传统的OCR流程,5走表格的OCR流程。
## 2. LayoutParser
[文档](layout/README.md)
## 3. Table OCR
[文档](table/README_ch.md)
## 4. 预测引擎推理
使用如下命令即可完成预测引擎的推理
```python
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片会output字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,excel文件名为表格在图片里的坐标。
## 5. PaddleStructure whl包介绍
### 5.1 使用
5.1.1 代码使用
```python
import os
import cv2
from paddlestructure import PaddleStructure,draw_result,save_res
table_engine = PaddleStructure(show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
5.1.2 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
### 参数说明
大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
| 字段 | 说明 | 默认值 |
|------------------------|------------------------------------------------------|------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
# 版面分析使用说明
* [1. 安装whl包](#安装whl包)
* [2. 使用](#使用)
* [3. 后处理](#后处理)
* [4. 指标](#指标)
* [5. 训练版面分析模型](#训练版面分析模型)
<a name="安装whl包"></a>
## 1. 安装whl包
```bash
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip install -U layoutparser-0.0.0-py3-none-any.whl
```
<a name="使用"></a>
## 2. 使用
使用layoutparser识别给定文档的布局:
```python
import layoutparser as lp
image = cv2.imread("imags/paper-image.jpg")
image = image[..., ::-1]
# 加载模型
model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5,
label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},
enforce_cpu=False,
enable_mkldnn=True)
# 检测
layout = model.detect(image)
# 显示结果
lp.draw_box(image, layout, box_width=3, show_element_type=True)
```
下图展示了结果,不同颜色的检测框表示不同的类别,并通过`show_element_type`在框的左上角显示具体类别:
<div align="center">
<img src="../../doc/table/result_all.jpg" width = "600" />
</div>
`PaddleDetectionLayoutModel`函数参数说明如下:
| 参数 | 含义 | 默认值 | 备注 |
| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: |
| config_path | 模型配置路径 | None | 指定config_path会自动下载模型(仅第一次,之后模型存在,不会再下载) |
| model_path | 模型路径 | None | 本地模型路径,config_path和model_path必须设置一个,不能同时为None |
| threshold | 预测得分的阈值 | 0.5 | \ |
| input_shape | reshape之后图片尺寸 | [3,640,640] | \ |
| batch_size | 测试batch size | 1 | \ |
| label_map | 类别映射表 | None | 设置config_path时,可以为None,根据数据集名称自动获取label_map |
| enforce_cpu | 代码是否使用CPU运行 | False | 设置为False表示使用GPU,True表示强制使用CPU |
| enforce_mkldnn | CPU预测中是否开启MKLDNN加速 | True | \ |
| thread_num | 设置CPU线程数 | 10 | \ |
目前支持以下几种模型配置和label map,您可以通过修改 `--config_path``--label_map`使用这些模型,从而检测不同类型的内容:
| dataset | config_path | label_map |
| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
* TableBank word和TableBank latex分别在word文档、latex文档数据集训练;
* 下载TableBank数据集同时包含word和latex。
<a name="后处理"></a>
## 3. 后处理
版面分析检测包含多个类别,如果只想获取指定类别(如"Text"类别)的检测框、可以使用下述代码:
```python
# 首先过滤特定文本类型的区域
text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
# 因为在图像区域内可能检测到文本区域,所以只需要删除它们
text_blocks = lp.Layout([b for b in text_blocks \
if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
# 对文本区域排序并分配id
h, w = image.shape[:2]
left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
left_blocks = text_blocks.filter_by(left_interval, center=True)
left_blocks.sort(key = lambda b:b.coordinates[1])
right_blocks = [b for b in text_blocks if b not in left_blocks]
right_blocks.sort(key = lambda b:b.coordinates[1])
# 最终合并两个列表,并按顺序添加索引
text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
# 显示结果
lp.draw_box(image, text_blocks,
box_width=3,
show_element_id=True)
```
显示只有"Text"类别的结果:
<div align="center">
<img src="../../doc/table/result_text.jpg" width = "600" />
</div>
<a name="指标"></a>
## 4. 指标
| Dataset | mAP | CPU time cost | GPU time cost |
| --------- | ---- | ------------- | ------------- |
| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
| TableBank | 96.2 | 1968.4ms | 65.1ms |
**Envrionment:**
**CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz,24core
**GPU:** a single NVIDIA Tesla P40
<a name="训练版面分析模型"></a>
## 5. 训练版面分析模型
上述模型基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection) 训练,如果您想训练自己的版面分析模型,请参考:[train_layoutparser_model](train_layoutparser_model.md)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment