Commit 03bb378f authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix TRT8 core bug

parents a2a12fe4 2e9abcb9
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -53,7 +53,6 @@ class AttentionHead(nn.Layer): ...@@ -53,7 +53,6 @@ class AttentionHead(nn.Layer):
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1) output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output) probs = self.generator(output)
else: else:
targets = paddle.zeros(shape=[batch_size], dtype="int32") targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None probs = None
...@@ -75,7 +74,8 @@ class AttentionHead(nn.Layer): ...@@ -75,7 +74,8 @@ class AttentionHead(nn.Layer):
probs_step, axis=1)], axis=1) probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1) next_input = probs_step.argmax(axis=1)
targets = next_input targets = next_input
if not self.training:
probs = paddle.nn.functional.softmax(probs, axis=2)
return probs return probs
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is referred from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -275,7 +294,6 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -275,7 +294,6 @@ class ParallelSARDecoder(BaseDecoder):
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
label = label.cuda()
lab_embedding = self.embedding(label) lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim # bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1) out_enc = out_enc.unsqueeze(1)
......
...@@ -23,32 +23,40 @@ import numpy as np ...@@ -23,32 +23,40 @@ import numpy as np
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.elem_num = 30
self.max_text_length = 100 self.max_text_length = max_text_length
self.max_elem_length = 500 self.max_elem_length = max_elem_length
self.max_cell_num = 500 self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num) self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type self.loc_type = loc_type
self.in_max_len = in_max_len self.in_max_len = in_max_len
if self.loc_type == 1: if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
else: else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1) self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim) input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot return input_ont_hot
...@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer): ...@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
if len(fea.shape) == 3: if len(fea.shape) == 3:
pass pass
else: else:
last_shape = int(np.prod(fea.shape[2:])) # gry added last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0] batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = [] output_hiddens = []
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length+1): for i in range(self.max_elem_length + 1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer): ...@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
alpha = None alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length) max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0 i = 0
while i < max_elem_length+1: while i < max_elem_length + 1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num) temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer): ...@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
structure_probs_step = self.structure_generator(outputs) structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1 i += 1
output = paddle.concat(output_hiddens, axis=1) output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output) structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs) structure_probs = F.softmax(structure_probs)
...@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer): ...@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
loc_concat = paddle.concat([output, loc_fea], axis=2) loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat) loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds} return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer): class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__() super(AttentionGRUCell, self).__init__()
......
...@@ -11,64 +11,102 @@ ...@@ -11,64 +11,102 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/whai362/PSENet/blob/python3/models/neck/fpn.py
"""
import paddle.nn as nn import paddle.nn as nn
import paddle import paddle
import math import math
import paddle.nn.functional as F import paddle.nn.functional as F
class Conv_BN_ReLU(nn.Layer): class Conv_BN_ReLU(nn.Layer):
def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0): def __init__(self,
in_planes,
out_planes,
kernel_size=1,
stride=1,
padding=0):
super(Conv_BN_ReLU, self).__init__() super(Conv_BN_ReLU, self).__init__()
self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, self.conv = nn.Conv2D(
bias_attr=False) in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias_attr=False)
self.bn = nn.BatchNorm2D(out_planes, momentum=0.1) self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
self.relu = nn.ReLU() self.relu = nn.ReLU()
for m in self.sublayers(): for m in self.sublayers():
if isinstance(m, nn.Conv2D): if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Normal(0, math.sqrt(2. / n))) m.weight = paddle.create_parameter(
shape=m.weight.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Normal(
0, math.sqrt(2. / n)))
elif isinstance(m, nn.BatchNorm2D): elif isinstance(m, nn.BatchNorm2D):
m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(1.0)) m.weight = paddle.create_parameter(
m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(0.0)) shape=m.weight.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Constant(1.0))
m.bias = paddle.create_parameter(
shape=m.bias.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Constant(0.0))
def forward(self, x): def forward(self, x):
return self.relu(self.bn(self.conv(x))) return self.relu(self.bn(self.conv(x)))
class FPN(nn.Layer): class FPN(nn.Layer):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super(FPN, self).__init__() super(FPN, self).__init__()
# Top layer # Top layer
self.toplayer_ = Conv_BN_ReLU(in_channels[3], out_channels, kernel_size=1, stride=1, padding=0) self.toplayer_ = Conv_BN_ReLU(
in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
# Lateral layers # Lateral layers
self.latlayer1_ = Conv_BN_ReLU(in_channels[2], out_channels, kernel_size=1, stride=1, padding=0) self.latlayer1_ = Conv_BN_ReLU(
in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
self.latlayer2_ = Conv_BN_ReLU(in_channels[1], out_channels, kernel_size=1, stride=1, padding=0) self.latlayer2_ = Conv_BN_ReLU(
in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
self.latlayer3_ = Conv_BN_ReLU(in_channels[0], out_channels, kernel_size=1, stride=1, padding=0) self.latlayer3_ = Conv_BN_ReLU(
in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
# Smooth layers # Smooth layers
self.smooth1_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.smooth1_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.smooth2_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.smooth3_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.smooth2_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.smooth3_ = Conv_BN_ReLU(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.out_channels = out_channels * 4 self.out_channels = out_channels * 4
for m in self.sublayers(): for m in self.sublayers():
if isinstance(m, nn.Conv2D): if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', m.weight = paddle.create_parameter(
default_initializer=paddle.nn.initializer.Normal(0, shape=m.weight.shape,
math.sqrt(2. / n))) dtype='float32',
default_initializer=paddle.nn.initializer.Normal(
0, math.sqrt(2. / n)))
elif isinstance(m, nn.BatchNorm2D): elif isinstance(m, nn.BatchNorm2D):
m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', m.weight = paddle.create_parameter(
default_initializer=paddle.nn.initializer.Constant(1.0)) shape=m.weight.shape,
m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', dtype='float32',
default_initializer=paddle.nn.initializer.Constant(0.0)) default_initializer=paddle.nn.initializer.Constant(1.0))
m.bias = paddle.create_parameter(
shape=m.bias.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Constant(0.0))
def _upsample(self, x, scale=1): def _upsample(self, x, scale=1):
return F.upsample(x, scale_factor=scale, mode='bilinear') return F.upsample(x, scale_factor=scale, mode='bilinear')
...@@ -81,15 +119,15 @@ class FPN(nn.Layer): ...@@ -81,15 +119,15 @@ class FPN(nn.Layer):
p5 = self.toplayer_(f5) p5 = self.toplayer_(f5)
f4 = self.latlayer1_(f4) f4 = self.latlayer1_(f4)
p4 = self._upsample_add(p5, f4,2) p4 = self._upsample_add(p5, f4, 2)
p4 = self.smooth1_(p4) p4 = self.smooth1_(p4)
f3 = self.latlayer2_(f3) f3 = self.latlayer2_(f3)
p3 = self._upsample_add(p4, f3,2) p3 = self._upsample_add(p4, f3, 2)
p3 = self.smooth2_(p3) p3 = self.smooth2_(p3)
f2 = self.latlayer3_(f2) f2 = self.latlayer3_(f2)
p2 = self._upsample_add(p3, f2,2) p2 = self._upsample_add(p3, f2, 2)
p2 = self.smooth3_(p2) p2 = self.smooth3_(p2)
p3 = self._upsample(p3, 2) p3 = self._upsample(p3, 2)
...@@ -97,4 +135,4 @@ class FPN(nn.Layer): ...@@ -97,4 +135,4 @@ class FPN(nn.Layer):
p5 = self._upsample(p5, 8) p5 = self._upsample(p5, 8)
fuse = paddle.concat([p2, p3, p4, p5], axis=1) fuse = paddle.concat([p2, p3, p4, p5], axis=1)
return fuse return fuse
\ No newline at end of file
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/stn_head.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/tps_spatial_transformer.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
import platform
__all__ = ['build_post_process'] __all__ = ['build_post_process']
...@@ -26,21 +25,24 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess ...@@ -26,21 +25,24 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode' 'SEEDLabelDecode'
] ]
if config['name'] == 'PSEPostProcess':
from .pse_postprocess import PSEPostProcess
support_dict.append('PSEPostProcess')
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
if global_config is not None: if global_config is not None:
......
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -190,7 +193,8 @@ class DBPostProcess(object): ...@@ -190,7 +193,8 @@ class DBPostProcess(object):
class DistillationDBPostProcess(object): class DistillationDBPostProcess(object):
def __init__(self, model_name=["student"], def __init__(self,
model_name=["student"],
key=None, key=None,
thresh=0.3, thresh=0.3,
box_thresh=0.6, box_thresh=0.6,
...@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object): ...@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
self.post_process = DBPostProcess(thresh=thresh, self.post_process = DBPostProcess(
box_thresh=box_thresh, thresh=thresh,
max_candidates=max_candidates, box_thresh=box_thresh,
unclip_ratio=unclip_ratio, max_candidates=max_candidates,
use_dilation=use_dilation, unclip_ratio=unclip_ratio,
score_mode=score_mode) use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from .locality_aware_nms import nms_locality from .locality_aware_nms import nms_locality
import cv2 import cv2
import paddle import paddle
import lanms
import os import os
import sys import sys
...@@ -29,6 +30,7 @@ class EASTPostProcess(object): ...@@ -29,6 +30,7 @@ class EASTPostProcess(object):
""" """
The post process for EAST. The post process for EAST.
""" """
def __init__(self, def __init__(self,
score_thresh=0.8, score_thresh=0.8,
cover_thresh=0.1, cover_thresh=0.1,
...@@ -38,11 +40,6 @@ class EASTPostProcess(object): ...@@ -38,11 +40,6 @@ class EASTPostProcess(object):
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.cover_thresh = cover_thresh self.cover_thresh = cover_thresh
self.nms_thresh = nms_thresh self.nms_thresh = nms_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def restore_rectangle_quad(self, origin, geometry): def restore_rectangle_quad(self, origin, geometry):
""" """
...@@ -79,11 +76,8 @@ class EASTPostProcess(object): ...@@ -79,11 +76,8 @@ class EASTPostProcess(object):
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8)) boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
if self.is_python35: boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
import lanms # boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
else:
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0: if boxes.shape[0] == 0:
return [] return []
# Here we filter some low score boxes by the average score map, # Here we filter some low score boxes by the average score map,
...@@ -139,4 +133,4 @@ class EASTPostProcess(object): ...@@ -139,4 +133,4 @@ class EASTPostProcess(object):
continue continue
boxes_norm.append(box) boxes_norm.append(box)
dt_boxes_list.append({'points': np.array(boxes_norm)}) dt_boxes_list.append({'points': np.array(boxes_norm)})
return dt_boxes_list return dt_boxes_list
\ No newline at end of file
""" """
Locality aware nms. Locality aware nms.
This code is referred from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
""" """
import numpy as np import numpy as np
......
## 编译 ## 编译
code from https://github.com/whai362/pan_pp.pytorch This code is referred from:
https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse
```python ```python
python3 setup.py build_ext --inplace python3 setup.py build_ext --inplace
``` ```
...@@ -21,8 +21,9 @@ ori_path = os.getcwd() ...@@ -21,8 +21,9 @@ ori_path = os.getcwd()
os.chdir('ppocr/postprocess/pse_postprocess/pse') os.chdir('ppocr/postprocess/pse_postprocess/pse')
if subprocess.call( if subprocess.call(
'{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0: '{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
raise RuntimeError('Cannot compile pse: {}'.format( raise RuntimeError(
os.path.dirname(os.path.realpath(__file__)))) 'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'.
format(os.path.dirname(os.path.realpath(__file__))))
os.chdir(ori_path) os.chdir(ori_path)
from .pse import pse from .pse import pse
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -47,7 +51,8 @@ class PSEPostProcess(object): ...@@ -47,7 +51,8 @@ class PSEPostProcess(object):
pred = outs_dict['maps'] pred = outs_dict['maps']
if not isinstance(pred, paddle.Tensor): if not isinstance(pred, paddle.Tensor):
pred = paddle.to_tensor(pred) pred = paddle.to_tensor(pred)
pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear') pred = F.interpolate(
pred, scale_factor=4 // self.scale, mode='bilinear')
score = F.sigmoid(pred[:, 0, :, :]) score = F.sigmoid(pred[:, 0, :, :])
...@@ -60,7 +65,9 @@ class PSEPostProcess(object): ...@@ -60,7 +65,9 @@ class PSEPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], shape_list[batch_index]) boxes, scores = self.boxes_from_bitmap(score[batch_index],
kernels[batch_index],
shape_list[batch_index])
boxes_batch.append({'points': boxes, 'scores': scores}) boxes_batch.append({'points': boxes, 'scores': scores})
return boxes_batch return boxes_batch
...@@ -98,15 +105,14 @@ class PSEPostProcess(object): ...@@ -98,15 +105,14 @@ class PSEPostProcess(object):
mask = np.zeros((box_height, box_width), np.uint8) mask = np.zeros((box_height, box_width), np.uint8)
mask[points[:, 1], points[:, 0]] = 255 mask[points[:, 1], points[:, 0]] = 255
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)
bbox = np.squeeze(contours[0], 1) bbox = np.squeeze(contours[0], 1)
else: else:
raise NotImplementedError raise NotImplementedError
bbox[:, 0] = np.clip( bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
np.round(bbox[:, 0] / ratio_w), 0, src_w) bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
bbox[:, 1] = np.clip(
np.round(bbox[:, 1] / ratio_h), 0, src_h)
boxes.append(bbox) boxes.append(bbox)
scores.append(score_i) scores.append(score_i)
return boxes, scores return boxes, scores
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,18 +11,23 @@ ...@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/whai362/PSENet/blob/python3/models/loss/iou.py
"""
import paddle import paddle
EPS = 1e-6 EPS = 1e-6
def iou_single(a, b, mask, n_class): def iou_single(a, b, mask, n_class):
valid = mask == 1 valid = mask == 1
a = a.masked_select(valid) a = a.masked_select(valid)
b = b.masked_select(valid) b = b.masked_select(valid)
miou = [] miou = []
for i in range(n_class): for i in range(n_class):
if a.shape == [0] and a.shape==b.shape: if a.shape == [0] and a.shape == b.shape:
inter = paddle.to_tensor(0.0) inter = paddle.to_tensor(0.0)
union = paddle.to_tensor(0.0) union = paddle.to_tensor(0.0)
else: else:
...@@ -32,6 +37,7 @@ def iou_single(a, b, mask, n_class): ...@@ -32,6 +37,7 @@ def iou_single(a, b, mask, n_class):
miou = sum(miou) / len(miou) miou = sum(miou) / len(miou)
return miou return miou
def iou(a, b, mask, n_class=2, reduce=True): def iou(a, b, mask, n_class=2, reduce=True):
batch_size = a.shape[0] batch_size = a.shape[0]
...@@ -39,10 +45,10 @@ def iou(a, b, mask, n_class=2, reduce=True): ...@@ -39,10 +45,10 @@ def iou(a, b, mask, n_class=2, reduce=True):
b = b.reshape([batch_size, -1]) b = b.reshape([batch_size, -1])
mask = mask.reshape([batch_size, -1]) mask = mask.reshape([batch_size, -1])
iou = paddle.zeros((batch_size,), dtype='float32') iou = paddle.zeros((batch_size, ), dtype='float32')
for i in range(batch_size): for i in range(batch_size):
iou[i] = iou_single(a[i], b[i], mask[i], n_class) iou[i] = iou_single(a[i], b[i], mask[i], n_class)
if reduce: if reduce:
iou = paddle.mean(iou) iou = paddle.mean(iou)
return iou return iou
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is referred from:
https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py
"""
import os import os
import sys import sys
......
...@@ -24,15 +24,17 @@ from ppocr.utils.logging import get_logger ...@@ -24,15 +24,17 @@ from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path): def download_with_progressbar(url, save_path):
logger = get_logger() logger = get_logger()
response = requests.get(url, stream=True) response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0)) if response.status_code == 200:
block_size = 1024 # 1 Kibibyte total_size_in_bytes = int(response.headers.get('content-length', 1))
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) block_size = 1024 # 1 Kibibyte
with open(save_path, 'wb') as file: progress_bar = tqdm(
for data in response.iter_content(block_size): total=total_size_in_bytes, unit='iB', unit_scale=True)
progress_bar.update(len(data)) with open(save_path, 'wb') as file:
file.write(data) for data in response.iter_content(block_size):
progress_bar.close() progress_bar.update(len(data))
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: file.write(data)
progress_bar.close()
else:
logger.error("Something went wrong while downloading models") logger.error("Something went wrong while downloading models")
sys.exit(0) sys.exit(0)
...@@ -45,7 +47,7 @@ def maybe_download(model_storage_directory, url): ...@@ -45,7 +47,7 @@ def maybe_download(model_storage_directory, url):
if not os.path.exists( if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams') os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists( ) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')): os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package' assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path)) print('download {} to {}'.format(url, tmp_path))
......
...@@ -25,7 +25,7 @@ import paddle ...@@ -25,7 +25,7 @@ import paddle
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_params'] __all__ = ['load_model']
def _mkdir_if_not_exist(path, logger): def _mkdir_if_not_exist(path, logger):
...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): ...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def init_model(config, model, optimizer=None, lr_scheduler=None): def load_model(config, model, optimizer=None):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
...@@ -54,15 +54,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): ...@@ -54,15 +54,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if checkpoints: if checkpoints:
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \ assert os.path.exists(checkpoints + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(checkpoints) "The {}.pdparams does not exists!".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints) # load params from trained model
para_dict = paddle.load(checkpoints + '.pdparams') params = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt') state_dict = model.state_dict()
model.set_state_dict(para_dict) new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(opti_dict) optimizer.set_state_dict(optim_dict)
if os.path.exists(checkpoints + '.states'): if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f: with open(checkpoints + '.states', 'rb') as f:
...@@ -73,70 +89,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): ...@@ -73,70 +89,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
if not isinstance(pretrained_model, list): load_pretrained_params(model, pretrained_model)
pretrained_model = [pretrained_model]
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
else: else:
logger.info('train from scratch') logger.info('train from scratch')
return best_model_dict return best_model_dict
def load_dygraph_params(config, model, logger, optimizer):
ckp = config['Global']['checkpoints']
if ckp and os.path.exists(ckp + ".pdparams"):
pre_best_model_dict = init_model(config, model, optimizer)
return pre_best_model_dict
else:
pm = config['Global']['pretrained_model']
if pm is None:
return {}
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
logger.info(f"The pretrained_model {pm} does not exists!")
return {}
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
if path is None: logger = get_logger()
return False if path.endswith('.pdparams'):
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"): path = path.replace('.pdparams', '')
print(f"The pretrained_model {path} does not exists!") assert os.path.exists(path + ".pdparams"), \
return False "The {}.pdparams does not exists!".format(path)
path = path if path.endswith('.pdparams') else path + '.pdparams' params = paddle.load(path + '.pdparams')
params = paddle.load(path)
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
print( logger.warning(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" "The shape of model params {} {} not matched with loaded params {} {} !".
) format(k1, state_dict[k1].shape, k2, params[k2].shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}") logger.info("load pretrain successful from {}".format(path))
return model return model
......
shapely shapely
scikit-image==0.18.3 scikit-image
imgaug==0.4.0 imgaug==0.4.0
pyclipper pyclipper
lmdb lmdb
...@@ -12,4 +12,5 @@ cython ...@@ -12,4 +12,5 @@ cython
lxml lxml
premailer premailer
openpyxl openpyxl
fasttext==0.9.1 fasttext==0.9.1
\ No newline at end of file lanms-nova
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment