Commit d73ed79c authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'Evezerest/dygraph' into dygraph

parents af77d08c 2945abd7
...@@ -30,21 +30,17 @@ class CenterLoss(nn.Layer): ...@@ -30,21 +30,17 @@ class CenterLoss(nn.Layer):
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
""" """
def __init__(self, def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.feat_dim = feat_dim self.feat_dim = feat_dim
self.centers = paddle.randn( self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype("float64") shape=[self.num_classes, self.feat_dim]).astype("float64")
if init_center: if center_file_path is not None:
assert os.path.exists( assert os.path.exists(
center_file_path center_file_path
), f"center path({center_file_path}) must exist when init_center is set as True." ), f"center path({center_file_path}) must exist when it is not None."
with open(center_file_path, 'rb') as f: with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f) char_dict = pickle.load(f)
for key in char_dict.keys(): for key in char_dict.keys():
......
...@@ -16,7 +16,7 @@ __all__ = ["build_backbone"] ...@@ -16,7 +16,7 @@ __all__ = ["build_backbone"]
def build_backbone(config, model_type): def build_backbone(config, model_type):
if model_type == "det": if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST from .det_resnet_vd_sast import ResNet_SAST
...@@ -36,10 +36,6 @@ def build_backbone(config, model_type): ...@@ -36,10 +36,6 @@ def build_backbone(config, model_type):
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
support_dict = ["ResNet"] support_dict = ["ResNet"]
elif model_type == "table":
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -26,8 +26,10 @@ class MobileNetV3(nn.Layer): ...@@ -26,8 +26,10 @@ class MobileNetV3(nn.Layer):
scale=0.5, scale=0.5,
large_stride=None, large_stride=None,
small_stride=None, small_stride=None,
disable_se=False,
**kwargs): **kwargs):
super(MobileNetV3, self).__init__() super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if small_stride is None: if small_stride is None:
small_stride = [2, 2, 2, 2] small_stride = [2, 2, 2, 2]
if large_stride is None: if large_stride is None:
...@@ -101,6 +103,7 @@ class MobileNetV3(nn.Layer): ...@@ -101,6 +103,7 @@ class MobileNetV3(nn.Layer):
block_list = [] block_list = []
inplanes = make_divisible(inplanes * scale) inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg: for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
block_list.append( block_list.append(
ResidualUnit( ResidualUnit(
in_channels=inplanes, in_channels=inplanes,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
__all__ = ['MobileNetV3']
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class MobileNetV3(nn.Layer):
def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hardswish', 2],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', 2],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hardswish', 2],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', 2],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hardswish',
name='conv1')
self.stages = []
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
start_idx = 2 if model_name == 'large' else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hardswish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
def forward(self, x):
x = self.conv(x)
out_list = []
for stage in self.stages:
x = stage(x)
out_list.append(x)
return out_list
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=None,
param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
out = []
for block in self.stages:
y = block(y)
out.append(y)
return out
...@@ -53,7 +53,6 @@ class AttentionHead(nn.Layer): ...@@ -53,7 +53,6 @@ class AttentionHead(nn.Layer):
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1) output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output) probs = self.generator(output)
else: else:
targets = paddle.zeros(shape=[batch_size], dtype="int32") targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None probs = None
...@@ -75,6 +74,7 @@ class AttentionHead(nn.Layer): ...@@ -75,6 +74,7 @@ class AttentionHead(nn.Layer):
probs_step, axis=1)], axis=1) probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1) next_input = probs_step.argmax(axis=1)
targets = next_input targets = next_input
if not self.training:
probs = paddle.nn.functional.softmax(probs, axis=2) probs = paddle.nn.functional.softmax(probs, axis=2)
return probs return probs
......
...@@ -53,7 +53,7 @@ def compute_partial_repr(input_points, control_points): ...@@ -53,7 +53,7 @@ def compute_partial_repr(input_points, control_points):
1] 1]
repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist) repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0 # fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix mask = np.array(repr_matrix != repr_matrix)
repr_matrix[mask] = 0 repr_matrix[mask] = 0
return repr_matrix return repr_matrix
......
...@@ -29,6 +29,7 @@ class EASTPostProcess(object): ...@@ -29,6 +29,7 @@ class EASTPostProcess(object):
""" """
The post process for EAST. The post process for EAST.
""" """
def __init__(self, def __init__(self,
score_thresh=0.8, score_thresh=0.8,
cover_thresh=0.1, cover_thresh=0.1,
...@@ -38,11 +39,6 @@ class EASTPostProcess(object): ...@@ -38,11 +39,6 @@ class EASTPostProcess(object):
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.cover_thresh = cover_thresh self.cover_thresh = cover_thresh
self.nms_thresh = nms_thresh self.nms_thresh = nms_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def restore_rectangle_quad(self, origin, geometry): def restore_rectangle_quad(self, origin, geometry):
""" """
...@@ -64,6 +60,7 @@ class EASTPostProcess(object): ...@@ -64,6 +60,7 @@ class EASTPostProcess(object):
""" """
restore text boxes from score map and geo map restore text boxes from score map and geo map
""" """
score_map = score_map[0] score_map = score_map[0]
geo_map = np.swapaxes(geo_map, 1, 0) geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = np.swapaxes(geo_map, 1, 2) geo_map = np.swapaxes(geo_map, 1, 2)
...@@ -79,10 +76,14 @@ class EASTPostProcess(object): ...@@ -79,10 +76,14 @@ class EASTPostProcess(object):
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8)) boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
if self.is_python35:
try:
import lanms import lanms
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh) boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
else: except:
print(
'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
)
boxes = nms_locality(boxes.astype(np.float64), nms_thresh) boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0: if boxes.shape[0] == 0:
return [] return []
...@@ -139,4 +140,4 @@ class EASTPostProcess(object): ...@@ -139,4 +140,4 @@ class EASTPostProcess(object):
continue continue
boxes_norm.append(box) boxes_norm.append(box)
dt_boxes_list.append({'points': np.array(boxes_norm)}) dt_boxes_list.append({'points': np.array(boxes_norm)})
return dt_boxes_list return dt_boxes_list
\ No newline at end of file
...@@ -54,14 +54,37 @@ def load_model(config, model, optimizer=None): ...@@ -54,14 +54,37 @@ def load_model(config, model, optimizer=None):
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if checkpoints: if checkpoints:
if checkpoints.endswith('pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \ assert os.path.exists(checkpoints + ".pdparams"), \
f"The {checkpoints}.pdopt does not exists!" "The {}.pdparams does not exists!".format(checkpoints)
load_pretrained_params(model, checkpoints)
optim_dict = paddle.load(checkpoints + '.pdopt') # load params from trained model
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
continue
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(optim_dict) if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
if os.path.exists(checkpoints + '.states'): if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f: with open(checkpoints + '.states', 'rb') as f:
...@@ -80,10 +103,10 @@ def load_model(config, model, optimizer=None): ...@@ -80,10 +103,10 @@ def load_model(config, model, optimizer=None):
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
logger = get_logger() logger = get_logger()
if path.endswith('pdparams'): if path.endswith('.pdparams'):
path = path.replace('.pdparams', '') path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \ assert os.path.exists(path + ".pdparams"), \
f"The {path}.pdparams does not exists!" "The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
...@@ -92,11 +115,11 @@ def load_pretrained_params(model, path): ...@@ -92,11 +115,11 @@ def load_pretrained_params(model, path):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
logger.info( logger.warning(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" "The shape of model params {} {} not matched with loaded params {} {} !".
) format(k1, state_dict[k1].shape, k2, params[k2].shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info(f"load pretrain successful from {path}") logger.info("load pretrain successful from {}".format(path))
return model return model
......
# 视觉问答(VQA)
VQA主要特性如下:
- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取(比如判断问题对)
- 支持SER任务与OCR引擎联合的端到端系统预测与评估。
- 支持SER任务和RE任务的自定义训练
本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
## 1. 效果演示
**注意:** 测试图片来源于XFUN数据集。
### 1.1 SER
<div align="center">
<img src="./images/result_ser/zh_val_0_ser.jpg" width = "600" />
</div>
<div align="center">
<img src="./images/result_ser/zh_val_42_ser.jpg" width = "600" />
</div>
其中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别,在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
### 1.2 RE
* Coming soon!
## 2. 安装
### 2.1 安装依赖
- **(1) 安装PaddlePaddle**
```bash
pip3 install --upgrade pip
# GPU安装
python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
# CPU安装
python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
### 2.2 安装PaddleOCR(包含 PP-OCR 和 VQA )
- **(1)pip快速安装PaddleOCR whl包(仅预测)**
```bash
pip install "paddleocr>=2.2" # 推荐使用2.2+版本
```
- **(2)下载VQA源码(预测+训练)**
```bash
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
- **(3)安装PaddleNLP**
```bash
# 需要使用PaddleNLP最新的代码版本进行安装
git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
cd PaddleNLP
pip install -e .
```
- **(4)安装VQA的`requirements`**
```bash
pip install -r requirements.txt
```
## 3. 使用
### 3.1 数据和预训练模型准备
处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)
如果希望直接体验预测过程,可以下载我们提供的SER预训练模型,跳过训练过程,直接预测即可。
* SER任务预训练模型下载链接:[链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
* RE任务预训练模型下载链接:coming soon!
### 3.2 SER任务
* 启动训练
```shell
python train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
--save_steps 500 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--evaluate_during_training \
--seed 2048
```
最终会打印出`precision`, `recall`, `f1`等指标,如下所示。
```
best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
```
模型和训练日志会保存在`./output/ser/`文件夹中。
* 使用评估集合中提供的OCR识别结果进行预测
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser.py \
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
--output_dir "output_res/" \
--infer_imgs "XFUND/zh_val/image/" \
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
```
最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`
* 使用`OCR引擎 + SER`串联结果
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
--max_seq_length 512 \
--output_dir "output_res_e2e/"
```
*`OCR引擎 + SER`预测系统进行端到端评估
```shell
export CUDA_VISIBLE_DEVICES=0
python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
3.3 RE任务
coming soon!
## 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
- XFUND dataset, https://github.com/doc-analysis/XFUND
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import sys
# import Polygon
import shapely
from shapely.geometry import Polygon
import numpy as np
from collections import defaultdict
import operator
import editdistance
import argparse
import json
import copy
def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
# img/zh_val_0.jpg {
# "height": 3508,
# "width": 2480,
# "ocr_info": [
# {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
# {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
# ]
assert fp_type in ["gt", "pred"]
key = "label" if fp_type == "gt" else "pred"
res_dict = dict()
with open(fp, "r") as fin:
lines = fin.readlines()
for _, line in enumerate(lines):
img_path, info = line.strip().split("\t")
# get key
image_name = os.path.basename(img_path)
res_dict[image_name] = []
# get infos
json_info = json.loads(info)
for single_ocr_info in json_info["ocr_info"]:
label = single_ocr_info[key].upper()
if label in ["O", "OTHERS", "OTHER"]:
label = "O"
if ignore_background and label == "O":
continue
single_ocr_info["label"] = label
res_dict[image_name].append(copy.deepcopy(single_ocr_info))
return res_dict
def polygon_from_str(polygon_points):
"""
Create a shapely polygon object from gt or dt line.
"""
polygon_points = np.array(polygon_points).reshape(4, 2)
polygon = Polygon(polygon_points).convex_hull
return polygon
def polygon_iou(poly1, poly2):
"""
Intersection over union between two shapely polygons.
"""
if not poly1.intersects(
poly2): # this test is fast and can accelerate calculation
iou = 0
else:
try:
inter_area = poly1.intersection(poly2).area
union_area = poly1.area + poly2.area - inter_area
iou = float(inter_area) / union_area
except shapely.geos.TopologicalError:
# except Exception as e:
# print(e)
print('shapely.geos.TopologicalError occured, iou set to 0')
iou = 0
return iou
def ed(args, str1, str2):
if args.ignore_space:
str1 = str1.replace(" ", "")
str2 = str2.replace(" ", "")
if args.ignore_case:
str1 = str1.lower()
str2 = str2.lower()
return editdistance.eval(str1, str2)
def convert_bbox_to_polygon(bbox):
"""
bbox : [x1, y1, x2, y2]
output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
"""
xmin, ymin, xmax, ymax = bbox
poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
return poly
def eval_e2e(args):
# gt
gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
args.ignore_background)
# pred
dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
args.ignore_background)
assert set(gt_results.keys()) == set(dt_results.keys())
iou_thresh = args.iou_thres
num_gt_chars = 0
gt_count = 0
dt_count = 0
hit = 0
ed_sum = 0
for img_name in gt_results:
gt_info = gt_results[img_name]
gt_count += len(gt_info)
dt_info = dt_results[img_name]
dt_count += len(dt_info)
dt_match = [False] * len(dt_info)
gt_match = [False] * len(gt_info)
all_ious = defaultdict(tuple)
# gt: {text, label, bbox or poly}
for index_gt, gt in enumerate(gt_info):
if "poly" not in gt:
gt["poly"] = convert_bbox_to_polygon(gt["bbox"])
gt_poly = polygon_from_str(gt["poly"])
for index_dt, dt in enumerate(dt_info):
if "poly" not in dt:
dt["poly"] = convert_bbox_to_polygon(dt["bbox"])
dt_poly = polygon_from_str(dt["poly"])
iou = polygon_iou(dt_poly, gt_poly)
if iou >= iou_thresh:
all_ious[(index_gt, index_dt)] = iou
sorted_ious = sorted(
all_ious.items(), key=operator.itemgetter(1), reverse=True)
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
# matched gt and dt
for gt_dt_pair in sorted_gt_dt_pairs:
index_gt, index_dt = gt_dt_pair
if gt_match[index_gt] == False and dt_match[index_dt] == False:
gt_match[index_gt] = True
dt_match[index_dt] = True
# ocr rec results
gt_text = gt_info[index_gt]["text"]
dt_text = dt_info[index_dt]["text"]
# ser results
gt_label = gt_info[index_gt]["label"]
dt_label = dt_info[index_dt]["pred"]
if True: # ignore_masks[index_gt] == '0':
ed_sum += ed(args, gt_text, dt_text)
num_gt_chars += len(gt_text)
if gt_text == dt_text:
if args.ignore_ser_prediction or gt_label == dt_label:
hit += 1
# unmatched dt
for tindex, dt_match_flag in enumerate(dt_match):
if dt_match_flag == False:
dt_text = dt_info[tindex]["text"]
gt_text = ""
ed_sum += ed(args, dt_text, gt_text)
# unmatched gt
for tindex, gt_match_flag in enumerate(gt_match):
if gt_match_flag == False:
dt_text = ""
gt_text = gt_info[tindex]["text"]
ed_sum += ed(args, gt_text, dt_text)
num_gt_chars += len(gt_text)
eps = 1e-9
print("config: ", args)
print('hit, dt_count, gt_count', hit, dt_count, gt_count)
precision = hit / (dt_count + eps)
recall = hit / (gt_count + eps)
fmeasure = 2.0 * precision * recall / (precision + recall + eps)
avg_edit_dist_img = ed_sum / len(gt_results)
avg_edit_dist_field = ed_sum / (gt_count + eps)
character_acc = 1 - ed_sum / (num_gt_chars + eps)
print('character_acc: %.2f' % (character_acc * 100) + "%")
print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
print('precision: %.2f' % (precision * 100) + "%")
print('recall: %.2f' % (recall * 100) + "%")
print('fmeasure: %.2f' % (fmeasure * 100) + "%")
return
def parse_args():
"""
"""
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument(
"--gt_json_path",
default=None,
type=str,
required=True, )
parser.add_argument(
"--pred_json_path",
default=None,
type=str,
required=True, )
parser.add_argument("--iou_thres", default=0.5, type=float)
parser.add_argument(
"--ignore_case",
default=False,
type=str2bool,
help="whether to do lower case for the strs")
parser.add_argument(
"--ignore_space",
default=True,
type=str2bool,
help="whether to ignore space")
parser.add_argument(
"--ignore_background",
default=True,
type=str2bool,
help="whether to ignore other label")
parser.add_argument(
"--ignore_ser_prediction",
default=False,
type=str2bool,
help="whether to ignore ocr pred results")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
eval_e2e(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
def transfer_xfun_data(json_path=None, output_file=None):
with open(json_path, "r") as fin:
lines = fin.readlines()
json_info = json.loads(lines[0])
documents = json_info["documents"]
label_info = {}
with open(output_file, "w") as fout:
for idx, document in enumerate(documents):
img_info = document["img"]
document = document["document"]
image_path = img_info["fname"]
label_info["height"] = img_info["height"]
label_info["width"] = img_info["width"]
label_info["ocr_info"] = []
for doc in document:
label_info["ocr_info"].append({
"text": doc["text"],
"label": doc["label"],
"bbox": doc["box"],
"id": doc["id"],
"linking": doc["linking"],
"words": doc["words"]
})
fout.write(image_path + "\t" + json.dumps(
label_info, ensure_ascii=False) + "\n")
print("===ok====")
transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
import paddle
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding with larger size, reshape is carried out
max_seq_len = (
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}"
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img,
(224, 224)).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, label_map_path):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
_, label_map = get_bio_label_maps(label_map_path)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
preds_list):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
label2id_map, _ = get_bio_label_maps(label_map_path)
for key in label2id_map:
if key.startswith("I-"):
label2id_map[key] = label2id_map["B" + key[1:]]
id2label_map = dict()
for key in label2id_map:
val = label2id_map[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[pred_id]
return ocr_info
@paddle.no_grad()
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
# model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForTokenClassification.from_pretrained(
args.model_name_or_path)
model.eval()
# load ocr results json
ocr_results = dict()
with open(args.ocr_json_path, "r") as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
ocr_results[os.path.basename(img_name)] = json.loads(json_info)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
img = cv2.imread(img_path)
ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
inputs = preprocess(
tokenizer=tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
outputs = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
args.label_map_path, ocr_info, inputs["segment_offset_id"],
preds)
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": ocr_info,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(
os.path.join(args.output_dir, os.path.basename(img_path)),
img_res)
return
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def parse_ocr_info_for_ser(ocr_result):
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
@paddle.no_grad()
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForTokenClassification.from_pretrained(
args.model_name_or_path)
model.eval()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
label2id_map_for_draw[key] = label2id_map[key]
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
args.ocr_det_model_dir)
# loop for infer
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
img = cv2.imread(img_path)
ocr_result = ocr_engine.ocr(img_path, cls=False)
ocr_info = parse_ocr_info_for_ser(ocr_result)
inputs = preprocess(
tokenizer=tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
outputs = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds, id2label_map)
ocr_info = merge_preds_list_with_ocr_info(
ocr_info, inputs["segment_offset_id"], preds,
label2id_map_for_draw)
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": ocr_info,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_ser.jpg"), img_res)
return
if __name__ == "__main__":
args = parse_args()
infer(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment