Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
16c247ac
Commit
16c247ac
authored
Jun 21, 2021
by
MissPenguin
Browse files
refine
parent
7c8b2c8d
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
40 additions
and
101 deletions
+40
-101
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+12
-12
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+0
-20
ppocr/data/pubtab_dataset.py
ppocr/data/pubtab_dataset.py
+2
-20
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+1
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+10
-12
ppocr/modeling/necks/table_fpn.py
ppocr/modeling/necks/table_fpn.py
+10
-19
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-12
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer_table.py
tools/infer_table.py
+1
-3
tools/program.py
tools/program.py
+2
-1
No files found.
configs/table/table_mv3.yml
View file @
16c247ac
Global
:
Global
:
use_gpu
:
true
use_gpu
:
true
epoch_num
:
4
0
epoch_num
:
5
0
log_smooth_window
:
20
log_smooth_window
:
20
print_batch_step
:
5
print_batch_step
:
5
save_model_dir
:
./output/table_mv3/
save_model_dir
:
./output/table_mv3/
save_epoch_step
:
3
save_epoch_step
:
5
# evaluation is run every
50
00 iterations after the
400
0th iteration
# evaluation is run every
4
00 iterations after the 0th iteration
eval_batch_step
:
[
0
,
400
]
eval_batch_step
:
[
0
,
400
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
cal_metric_during_train
:
True
pretrained_model
:
pretrained_model
:
checkpoints
:
checkpoints
:
...
@@ -18,19 +17,20 @@ Global:
...
@@ -18,19 +17,20 @@ Global:
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
character_type
:
en
max_text_length
:
100
max_text_length
:
100
max_elem_length
:
8
00
max_elem_length
:
5
00
max_cell_num
:
500
max_cell_num
:
500
infer_mode
:
False
infer_mode
:
False
process_total_num
:
0
process_total_num
:
0
process_cut_num
:
0
process_cut_num
:
0
Optimizer
:
Optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.9
beta1
:
0.9
beta2
:
0.999
beta2
:
0.999
clip_norm
:
5.0
clip_norm
:
5.0
lr
:
lr
:
learning_rate
:
0.00
0
1
learning_rate
:
0.001
regularizer
:
regularizer
:
name
:
'
L2'
name
:
'
L2'
factor
:
0.00000
factor
:
0.00000
...
@@ -41,12 +41,12 @@ Architecture:
...
@@ -41,12 +41,12 @@ Architecture:
Backbone
:
Backbone
:
name
:
MobileNetV3
name
:
MobileNetV3
scale
:
1.0
scale
:
1.0
model_name
:
large
model_name
:
small
disable_se
:
True
Head
:
Head
:
name
:
TableAttentionHead
# AttentionHead
name
:
TableAttentionHead
hidden_size
:
256
#
hidden_size
:
256
l2_decay
:
0.00001
l2_decay
:
0.00001
# loc_type: 1
loc_type
:
2
loc_type
:
2
Loss
:
Loss
:
...
@@ -86,7 +86,7 @@ Train:
...
@@ -86,7 +86,7 @@ Train:
shuffle
:
True
shuffle
:
True
batch_size_per_card
:
32
batch_size_per_card
:
32
drop_last
:
True
drop_last
:
True
num_workers
:
4
num_workers
:
1
Eval
:
Eval
:
dataset
:
dataset
:
...
@@ -113,4 +113,4 @@ Eval:
...
@@ -113,4 +113,4 @@ Eval:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
batch_size_per_card
:
16
batch_size_per_card
:
16
num_workers
:
4
num_workers
:
1
ppocr/data/imaug/label_ops.py
View file @
16c247ac
...
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
...
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
return
None
return
None
elem_num
=
len
(
structure
)
elem_num
=
len
(
structure
)
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
# structure = [0] + structure + [0]
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
np
.
array
(
structure
)
structure
=
np
.
array
(
structure
)
data
[
'structure'
]
=
structure
data
[
'structure'
]
=
structure
...
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
...
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
structure
[
cand_span_idx
]
in
span_idx_list
:
if
structure
[
cand_span_idx
]
in
span_idx_list
:
structure_mask
[
cand_span_idx
]
=
span_weight
structure_mask
[
cand_span_idx
]
=
span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list_mask'
]
=
bbox_list_mask
data
[
'bbox_list_mask'
]
=
bbox_list_mask
...
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
...
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
return
data
return
data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def
encode
(
self
,
text
,
char_or_elem
):
def
encode
(
self
,
text
,
char_or_elem
):
"""convert text-label into text-index.
"""convert text-label into text-index.
"""
"""
...
...
ppocr/data/pubtab_dataset.py
View file @
16c247ac
# copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -19,6 +19,7 @@ import json
...
@@ -19,6 +19,7 @@ import json
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
class
PubTabDataSet
(
Dataset
):
class
PubTabDataSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
PubTabDataSet
,
self
).
__init__
()
super
(
PubTabDataSet
,
self
).
__init__
()
...
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
...
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
load_hard_select_prob
(
self
):
label_path
=
"./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob
=
{}
with
open
(
label_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
lno
in
range
(
len
(
lines
)):
substr
=
lines
[
lno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
split
(
" "
)
img_name
=
substr
[
0
].
strip
(
":"
)
score
=
float
(
substr
[
1
])
if
score
<=
0.8
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
0
]
elif
score
<=
0.98
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
1
]
else
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
2
]
return
img_select_prob
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
try
:
try
:
data_line
=
self
.
data_lines
[
idx
]
data_line
=
self
.
data_lines
[
idx
]
...
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
...
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
table_type
=
"simple"
table_type
=
"simple"
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
table_type
=
"complex"
table_type
=
"complex"
# if self.table_select_type != table_type:
# select_flag = False
if
table_type
==
"complex"
:
if
table_type
==
"complex"
:
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
select_flag
=
False
select_flag
=
False
...
...
ppocr/modeling/architectures/base_model.py
View file @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
ppocr/modeling/heads/table_att_head.py
View file @
16c247ac
...
@@ -21,13 +21,16 @@ import paddle.nn as nn
...
@@ -21,13 +21,16 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
import
numpy
as
np
import
numpy
as
np
class
TableAttentionHead
(
nn
.
Layer
):
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
char_num
=
280
self
.
elem_num
=
30
self
.
elem_num
=
30
self
.
max_text_length
=
100
self
.
max_elem_length
=
500
self
.
max_cell_num
=
500
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
...
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
...
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
else
:
if
self
.
in_max_len
==
640
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
elif
self
.
in_max_len
==
800
:
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
else
:
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
...
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
...
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
batch_size
=
fea
.
shape
[
0
]
batch_size
=
fea
.
shape
[
0
]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length
,
max_elem_length
,
max_cell_num
=
100
,
800
,
500
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
output_hiddens
=
[]
if
mode
==
'Train'
and
targets
is
not
None
:
if
mode
==
'Train'
and
targets
is
not
None
:
structure
=
targets
[
0
]
structure
=
targets
[
0
]
for
i
in
range
(
max_elem_length
+
1
):
for
i
in
range
(
self
.
max_elem_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
elem_onehots
=
None
elem_onehots
=
None
outputs
=
None
outputs
=
None
alpha
=
None
alpha
=
None
max_elem_length
=
paddle
.
to_tensor
(
max_elem_length
)
max_elem_length
=
paddle
.
to_tensor
(
self
.
max_elem_length
)
i
=
0
i
=
0
while
i
<
max_elem_length
+
1
:
while
i
<
max_elem_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
...
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
loc_preds
=
F
.
sigmoid
(
loc_preds
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
super
(
AttentionGRUCell
,
self
).
__init__
()
...
...
ppocr/modeling/necks/table_fpn.py
View file @
16c247ac
# copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 20
2
1 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
...
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
in_channels
=
in_channels
[
0
],
in_channels
=
in_channels
[
0
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_51.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in3_conv
=
nn
.
Conv2D
(
self
.
in3_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
1
],
in_channels
=
in_channels
[
1
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_50.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in4_conv
=
nn
.
Conv2D
(
self
.
in4_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
2
],
in_channels
=
in_channels
[
2
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_49.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in5_conv
=
nn
.
Conv2D
(
self
.
in5_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
3
],
in_channels
=
in_channels
[
3
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_48.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p5_conv
=
nn
.
Conv2D
(
self
.
p5_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_52.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p4_conv
=
nn
.
Conv2D
(
self
.
p4_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_53.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p3_conv
=
nn
.
Conv2D
(
self
.
p3_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_54.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p2_conv
=
nn
.
Conv2D
(
self
.
p2_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_55.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
fuse_conv
=
nn
.
Conv2D
(
self
.
fuse_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
*
4
,
in_channels
=
self
.
out_channels
*
4
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
name
=
'conv2d_fuse.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
c2
,
c3
,
c4
,
c5
=
x
...
...
ppocr/postprocess/rec_postprocess.py
View file @
16c247ac
...
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
...
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
return
list_character
def
get_sp_tokens
(
self
):
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
sp_tokens
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
])
return
sp_tokens
def
__call__
(
self
,
preds
):
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
loc_preds
=
preds
[
'loc_preds'
]
...
...
tools/export_model.py
View file @
16c247ac
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
)
infer_shape
[
-
1
]
=
100
infer_shape
[
-
1
]
=
100
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
=
to_static
(
model
,
model
,
input_spec
=
[
input_spec
=
[
...
...
tools/infer_table.py
View file @
16c247ac
...
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
...
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
img
=
f
.
read
()
img
=
f
.
read
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
batch
=
transform
(
data
,
ops
)
sp_tokens
=
post_process_class
.
get_sp_tokens
()
targets
=
[[],
[],
paddle
.
to_tensor
([
sp_tokens
])]
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
,
data
=
targets
,
mode
=
'Test'
)
preds
=
model
(
images
,
data
=
None
,
mode
=
'Test'
)
post_result
=
post_process_class
(
preds
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
res_loc
=
post_result
[
'res_loc'
]
...
...
tools/program.py
View file @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -276,6 +276,7 @@ def train(config,
...
@@ -276,6 +276,7 @@ def train(config,
valid_dataloader
,
valid_dataloader
,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
"table"
,
use_srn
=
use_srn
)
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment