wangsen / paddle_dbnet · Commits

Commit 41a1b292, authored Jan 20, 2022 by Leif
Merge remote-tracking branch 'origin/dygraph' into dygraph
Parents: 9471054e, 3d30899b

Changes: 162 files changed in the merge; showing 20 changed files with 1338 additions and 36 deletions (+1338, -36).
ppocr/metrics/__init__.py                                +4    -1
ppocr/metrics/cls_metric.py                              +3    -2
ppocr/metrics/rec_metric.py                              +5    -4
ppocr/metrics/table_metric.py                            +5    -4
ppocr/metrics/vqa_token_re_metric.py                     +176  -0
ppocr/metrics/vqa_token_ser_metric.py                    +47   -0
ppocr/modeling/architectures/base_model.py               +8    -3
ppocr/modeling/backbones/__init__.py                     +5    -1
ppocr/modeling/backbones/rec_micronet.py                 +528  -0
ppocr/modeling/backbones/vqa_layoutlm.py                 +125  -0
ppocr/optimizer/__init__.py                              +3    -1
ppocr/optimizer/learning_rate.py                         +51   -1
ppocr/optimizer/lr_scheduler.py                          +113  -0
ppocr/optimizer/optimizer.py                             +35   -0
ppocr/optimizer/regularizer.py                           +5    -6
ppocr/postprocess/__init__.py                            +4    -1
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py   +51   -0
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py  +93   -0
ppocr/utils/save_load.py                                 +49   -11
ppocr/utils/utility.py                                   +28   -1
ppocr/metrics/__init__.py (view file @ 41a1b292)
@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
 from .table_metric import TableMetric
 from .kie_metric import KIEMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric


 def build_metric(config):
     support_dict = [
         "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
-        "DistillationMetric", "TableMetric", 'KIEMetric'
+        "DistillationMetric", "TableMetric", 'KIEMetric',
+        'VQASerTokenMetric', 'VQAReTokenMetric'
     ]

     config = copy.deepcopy(config)
ppocr/metrics/cls_metric.py (view file @ 41a1b292)
@@ -16,6 +16,7 @@
 class ClsMetric(object):
     def __init__(self, main_indicator='acc', **kwargs):
         self.main_indicator = main_indicator
+        self.eps = 1e-5
         self.reset()

     def __call__(self, pred_label, *args, **kwargs):
@@ -28,7 +29,7 @@ class ClsMetric(object):
             all_num += 1
         self.correct_num += correct_num
         self.all_num += all_num
-        return {'acc': correct_num / all_num, }
+        return {'acc': correct_num / (all_num + self.eps), }

     def get_metric(self):
         """
@@ -36,7 +37,7 @@ class ClsMetric(object):
             'acc': 0
         }
         """
-        acc = self.correct_num / self.all_num
+        acc = self.correct_num / (self.all_num + self.eps)
         self.reset()
         return {'acc': acc}
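The eps term added above is what keeps both __call__ and get_metric from dividing by zero when a metric object is queried before any batch has been accumulated. A minimal standalone sketch of the pattern (illustration only, not part of the commit):

# Minimal sketch of the eps-guarded accuracy pattern used above.
# Standalone illustration; the class name is made up for this example.
class TinyAccMetric(object):
    def __init__(self, eps=1e-5):
        self.eps = eps
        self.correct_num = 0
        self.all_num = 0

    def update(self, preds, labels):
        for p, l in zip(preds, labels):
            self.correct_num += int(p == l)
            self.all_num += 1

    def get_metric(self):
        # eps in the denominator: returns 0.0 instead of raising
        # ZeroDivisionError when no samples were accumulated yet.
        return self.correct_num / (self.all_num + self.eps)


m = TinyAccMetric()
print(m.get_metric())            # 0.0, no crash on an empty accumulator
m.update([1, 0, 1], [1, 1, 1])
print(round(m.get_metric(), 4))  # ~0.6667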
ppocr/metrics/rec_metric.py (view file @ 41a1b292)
@@ -20,6 +20,7 @@ class RecMetric(object):
     def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
         self.main_indicator = main_indicator
         self.is_filter = is_filter
+        self.eps = 1e-5
         self.reset()

     def _normalize_text(self, text):
@@ -47,8 +48,8 @@ class RecMetric(object):
         self.all_num += all_num
         self.norm_edit_dis += norm_edit_dis
         return {
-            'acc': correct_num / all_num,
-            'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3)
+            'acc': correct_num / (all_num + self.eps),
+            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
         }

     def get_metric(self):
@@ -58,8 +59,8 @@ class RecMetric(object):
             'norm_edit_dis': 0,
         }
         """
-        acc = 1.0 * self.correct_num / (self.all_num + 1e-3)
-        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3)
+        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
         self.reset()
         return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
ppocr/metrics/table_metric.py (view file @ 41a1b292)
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np


 class TableMetric(object):
     def __init__(self, main_indicator='acc', **kwargs):
         self.main_indicator = main_indicator
+        self.eps = 1e-5
         self.reset()

     def __call__(self, pred, batch, *args, **kwargs):
@@ -31,9 +34,7 @@ class TableMetric(object):
             correct_num += 1
         self.correct_num += correct_num
         self.all_num += all_num
-        return {'acc': correct_num * 1.0 / all_num, }
+        return {'acc': correct_num * 1.0 / (all_num + self.eps), }

     def get_metric(self):
         """
@@ -41,7 +42,7 @@ class TableMetric(object):
             'acc': 0,
         }
         """
-        acc = 1.0 * self.correct_num / self.all_num
+        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
         self.reset()
         return {'acc': acc}
ppocr/metrics/vqa_token_re_metric.py (new file, 0 → 100644, view file @ 41a1b292)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle

__all__ = ['KIEMetric']


class VQAReTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            for head, tail in zip(self.relations_list[b]["head"],
                                  self.relations_list[b]["tail"]):
                rel = {}
                rel["head_id"] = head
                rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
                               self.entities_list[b]["end"][rel["head_id"]])
                rel["head_type"] = self.entities_list[b]["label"][rel[
                    "head_id"]]

                rel["tail_id"] = tail
                rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
                               self.entities_list[b]["end"][rel["tail_id"]])
                rel["tail_type"] = self.entities_list[b]["label"][rel[
                    "tail_id"]]

                rel["type"] = 1
                rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions

        Args:
            pred_relations (list): list of list of predicted relations (several relations in each sentence)
            gt_relations (list): list of list of ground truth relations
                rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                        "tail": (start_idx (inclusive), end_idx (exclusive)),
                        "head_type": ent_type,
                        "tail_type": ent_type,
                        "type": rel_type}
            vocab (Vocab): dataset vocabulary
            mode (str): in 'strict' or 'boundaries'"""

        assert mode in ["strict", "boundaries"]
        relation_types = [v for v in [0, 1] if not v == 0]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count GT relations and Predicted relations
        n_sents = len(gt_relations)
        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}

                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Compute per entity Precision / Recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro F1 Scores
        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0

        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Compute Macro F1 Scores
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[ent_type]["f1"] for ent_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[ent_type]["p"] for ent_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[ent_type]["r"] for ent_type in relation_types])

        return scores
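The strict/boundaries distinction in re_score comes down to which tuple goes into the comparison sets. A standalone toy example (hand-made relations, not from the commit) showing why a type mismatch counts in strict mode but not in boundaries mode:

# Toy illustration of strict vs. boundaries counting as done in re_score.
pred = [{"head": (0, 2), "head_type": "Q", "tail": (3, 5), "tail_type": "A", "type": 1}]
gt = [{"head": (0, 2), "head_type": "Q", "tail": (3, 5), "tail_type": "X", "type": 1}]

# boundaries mode compares spans only -> this prediction is a true positive
pred_b = {(r["head"], r["tail"]) for r in pred if r["type"] == 1}
gt_b = {(r["head"], r["tail"]) for r in gt if r["type"] == 1}
print(len(pred_b & gt_b))  # 1 (tp)

# strict mode also compares entity types -> tail_type mismatch, no tp
pred_s = {(r["head"], r["head_type"], r["tail"], r["tail_type"]) for r in pred}
gt_s = {(r["head"], r["head_type"], r["tail"], r["tail_type"]) for r in gt}
print(len(pred_s & gt_s))  # 0 (counted as one fp and one fn instead)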
ppocr/metrics/vqa_token_ser_metric.py (new file, 0 → 100644, view file @ 41a1b292)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle

__all__ = ['KIEMetric']


class VQASerTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        preds, labels = preds
        self.pred_list.extend(preds)
        self.gt_list.extend(labels)

    def get_metric(self):
        from seqeval.metrics import f1_score, precision_score, recall_score
        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_list = []
        self.gt_list = []
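get_metric delegates to seqeval, which scores whole BIO-tagged entities rather than individual tokens. A small sketch of those calls, assuming seqeval is installed (pip install seqeval) and using made-up QUESTION/ANSWER tags:

# Entity-level scoring with seqeval, as used in get_metric above.
# Inputs are lists of BIO tag sequences, one list per sample.
from seqeval.metrics import f1_score, precision_score, recall_score

gt_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
pred_list = [["B-QUESTION", "I-QUESTION", "O", "O"]]

# The QUESTION entity matches exactly; the ANSWER entity is missed,
# so precision is 1.0, recall 0.5, and the F1 (hmean) about 0.667.
print(precision_score(gt_list, pred_list))  # 1.0
print(recall_score(gt_list, pred_list))     # 0.5
print(f1_score(gt_list, pred_list))         # 0.666...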
ppocr/modeling/architectures/base_model.py (view file @ 41a1b292)
@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
             in_channels = self.neck.out_channels

         # # build head, head is need for det, rec and cls
-        config["Head"]['in_channels'] = in_channels
-        self.head = build_head(config["Head"])
+        if 'Head' not in config or config['Head'] is None:
+            self.use_head = False
+        else:
+            self.use_head = True
+            config["Head"]['in_channels'] = in_channels
+            self.head = build_head(config["Head"])

         self.return_all_feats = config.get("return_all_feats", False)
@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
         if self.use_neck:
             x = self.neck(x)
             y["neck_out"] = x
-        x = self.head(x, targets=data)
+        if self.use_head:
+            x = self.head(x, targets=data)
         if isinstance(x, dict):
             y.update(x)
         else:
ppocr/modeling/backbones/__init__.py (view file @ 41a1b292)
@@ -29,9 +29,10 @@ def build_backbone(config, model_type):
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
         from .rec_resnet_aster import ResNet_ASTER
+        from .rec_micronet import MicroNet
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
-            "ResNet31", "ResNet_ASTER"
+            "ResNet31", "ResNet_ASTER", 'MicroNet'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
@@ -43,6 +44,9 @@ def build_backbone(config, model_type):
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
+    elif model_type == 'vqa':
+        from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
+        support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
     else:
         raise NotImplementedError
ppocr/modeling/backbones/rec_micronet.py (new file, 0 → 100644, view file @ 41a1b292)

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn

from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible

M0_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
    [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
    [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
    [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
    [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
    [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M1_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
    [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
    [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
    [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
    [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
    [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M2_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
    [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
    [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
    [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
    [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
    [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
    [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
    [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
    [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
    [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
]
M3_cfgs = [
    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
    [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
    [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
    [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
    [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
    [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
    [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
    [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
    [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
    [1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
    [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
    [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
    [1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
]


def get_micronet_config(mode):
    return eval(mode + '_cfgs')


class MaxGroupPooling(nn.Layer):
    def __init__(self, channel_per_group=2):
        super(MaxGroupPooling, self).__init__()
        self.channel_per_group = channel_per_group

    def forward(self, x):
        if self.channel_per_group == 1:
            return x
        # max op
        b, c, h, w = x.shape

        # reshape
        y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
        out = paddle.max(y, axis=2)
        return out


class SpatialSepConvSF(nn.Layer):
    def __init__(self, inp, oups, kernel_size, stride):
        super(SpatialSepConvSF, self).__init__()

        oup1, oup2 = oups
        self.conv = nn.Sequential(
            nn.Conv2D(
                inp,
                oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
                bias_attr=False,
                groups=1),
            nn.BatchNorm2D(oup1),
            nn.Conv2D(
                oup1,
                oup1 * oup2, (1, kernel_size), (1, stride),
                (0, kernel_size // 2),
                bias_attr=False,
                groups=oup1),
            nn.BatchNorm2D(oup1 * oup2),
            ChannelShuffle(oup1), )

    def forward(self, x):
        out = self.conv(x)
        return out


class ChannelShuffle(nn.Layer):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        b, c, h, w = x.shape
        channels_per_group = c // self.groups

        # reshape
        x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])
        x = paddle.transpose(x, (0, 2, 1, 3, 4))
        out = paddle.reshape(x, [b, -1, h, w])
        return out


class StemLayer(nn.Layer):
    def __init__(self, inp, oup, stride, groups=(4, 4)):
        super(StemLayer, self).__init__()

        g1, g2 = groups
        self.stem = nn.Sequential(
            SpatialSepConvSF(inp, groups, 3, stride),
            MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())

    def forward(self, x):
        out = self.stem(x)
        return out


class DepthSpatialSepConv(nn.Layer):
    def __init__(self, inp, expand, kernel_size, stride):
        super(DepthSpatialSepConv, self).__init__()

        exp1, exp2 = expand
        hidden_dim = inp * exp1
        oup = inp * exp1 * exp2

        self.conv = nn.Sequential(
            nn.Conv2D(
                inp,
                inp * exp1, (kernel_size, 1), (stride, 1),
                (kernel_size // 2, 0),
                bias_attr=False,
                groups=inp),
            nn.BatchNorm2D(inp * exp1),
            nn.Conv2D(
                hidden_dim,
                oup, (1, kernel_size),
                1, (0, kernel_size // 2),
                bias_attr=False,
                groups=hidden_dim),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        x = self.conv(x)
        return x


class GroupConv(nn.Layer):
    def __init__(self, inp, oup, groups=2):
        super(GroupConv, self).__init__()
        self.inp = inp
        self.oup = oup
        self.groups = groups
        self.conv = nn.Sequential(
            nn.Conv2D(
                inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        x = self.conv(x)
        return x


class DepthConv(nn.Layer):
    def __init__(self, inp, oup, kernel_size, stride):
        super(DepthConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(
                inp,
                oup,
                kernel_size,
                stride,
                kernel_size // 2,
                bias_attr=False,
                groups=inp),
            nn.BatchNorm2D(oup))

    def forward(self, x):
        out = self.conv(x)
        return out


class DYShiftMax(nn.Layer):
    def __init__(self,
                 inp,
                 oup,
                 reduction=4,
                 act_max=1.0,
                 act_relu=True,
                 init_a=[0.0, 0.0],
                 init_b=[0.0, 0.0],
                 relu_before_pool=False,
                 g=None,
                 expansion=False):
        super(DYShiftMax, self).__init__()
        self.oup = oup
        self.act_max = act_max * 2
        self.act_relu = act_relu
        self.avg_pool = nn.Sequential(
            nn.ReLU() if relu_before_pool == True else nn.Sequential(),
            nn.AdaptiveAvgPool2D(1))

        self.exp = 4 if act_relu else 2
        self.init_a = init_a
        self.init_b = init_b

        # determine squeeze
        squeeze = make_divisible(inp // reduction, 4)
        if squeeze < 4:
            squeeze = 4

        self.fc = nn.Sequential(
            nn.Linear(inp, squeeze),
            nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())

        if g is None:
            g = 1
        self.g = g[1]
        if self.g != 1 and expansion:
            self.g = inp // self.g

        self.gc = inp // self.g
        index = paddle.to_tensor([range(inp)])
        index = paddle.reshape(index, [1, inp, 1, 1])
        index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
        indexgs = paddle.split(index, [1, self.g - 1], axis=1)
        indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
        indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
        indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
        self.index = paddle.reshape(indexs, [inp])
        self.expansion = expansion

    def forward(self, x):
        x_in = x
        x_out = x

        b, c, _, _ = x_in.shape
        y = self.avg_pool(x_in)
        y = paddle.reshape(y, [b, c])
        y = self.fc(y)
        y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
        y = (y - 0.5) * self.act_max

        n2, c2, h2, w2 = x_out.shape
        x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])

        if self.exp == 4:
            temp = y.shape
            a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)

            a1 = a1 + self.init_a[0]
            a2 = a2 + self.init_a[1]
            b1 = b1 + self.init_b[0]
            b2 = b2 + self.init_b[1]

            z1 = x_out * a1 + x2 * b1
            z2 = x_out * a2 + x2 * b2

            out = paddle.maximum(z1, z2)

        elif self.exp == 2:
            temp = y.shape
            a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
            a1 = a1 + self.init_a[0]
            b1 = b1 + self.init_b[0]
            out = x_out * a1 + x2 * b1

        return out


class DYMicroBlock(nn.Layer):
    def __init__(self,
                 inp,
                 oup,
                 kernel_size=3,
                 stride=1,
                 ch_exp=(2, 2),
                 ch_per_group=4,
                 groups_1x1=(1, 1),
                 depthsep=True,
                 shuffle=False,
                 activation_cfg=None):
        super(DYMicroBlock, self).__init__()

        self.identity = stride == 1 and inp == oup

        y1, y2, y3 = activation_cfg['dy']
        act_reduction = 8 * activation_cfg['ratio']
        init_a = activation_cfg['init_a']
        init_b = activation_cfg['init_b']

        t1 = ch_exp
        gs1 = ch_per_group
        hidden_fft, g1, g2 = groups_1x1
        hidden_dim2 = inp * t1[0] * t1[1]

        if gs1[0] == 0:
            self.layers = nn.Sequential(
                DepthSpatialSepConv(inp, t1, kernel_size, stride),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y2 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=False) if y2 > 0 else nn.ReLU6(),
                ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
                ChannelShuffle(hidden_dim2 // 2)
                if shuffle and y2 != 0 else nn.Sequential(),
                GroupConv(hidden_dim2, oup, (g1, g2)),
                DYShiftMax(
                    oup,
                    oup,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction // 2,
                    init_b=[0.0, 0.0],
                    g=(g1, g2),
                    expansion=False) if y3 > 0 else nn.Sequential(),
                ChannelShuffle(g2) if shuffle else nn.Sequential(),
                ChannelShuffle(oup // 2)
                if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
        elif g2 == 0:
            self.layers = nn.Sequential(
                GroupConv(inp, hidden_dim2, gs1),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction,
                    init_b=[0.0, 0.0],
                    g=gs1,
                    expansion=False) if y3 > 0 else nn.Sequential(), )
        else:
            self.layers = nn.Sequential(
                GroupConv(inp, hidden_dim2, gs1),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y1 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=False) if y1 > 0 else nn.ReLU6(),
                ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
                DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
                if depthsep else
                DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
                nn.Sequential(),
                DYShiftMax(
                    hidden_dim2,
                    hidden_dim2,
                    act_max=2.0,
                    act_relu=True if y2 == 2 else False,
                    init_a=init_a,
                    reduction=act_reduction,
                    init_b=init_b,
                    g=gs1,
                    expansion=True) if y2 > 0 else nn.ReLU6(),
                ChannelShuffle(hidden_dim2 // 4)
                if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
                if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
                GroupConv(hidden_dim2, oup, (g1, g2)),
                DYShiftMax(
                    oup,
                    oup,
                    act_max=2.0,
                    act_relu=False,
                    init_a=[1.0, 0.0],
                    reduction=act_reduction // 2
                    if oup < hidden_dim2 else act_reduction,
                    init_b=[0.0, 0.0],
                    g=(g1, g2),
                    expansion=False) if y3 > 0 else nn.Sequential(),
                ChannelShuffle(g2) if shuffle else nn.Sequential(),
                ChannelShuffle(oup // 2)
                if shuffle and y3 != 0 else nn.Sequential(), )

    def forward(self, x):
        identity = x
        out = self.layers(x)
        if self.identity:
            out = out + identity
        return out


class MicroNet(nn.Layer):
    """
    the MicroNet backbone network for recognition module.
    Args:
        mode(str): {'M0', 'M1', 'M2', 'M3'}
            Four models are proposed based on four different computational
            costs (4M, 6M, 12M, 21M MAdds). Default: 'M3'.
    """

    def __init__(self, mode='M3', **kwargs):
        super(MicroNet, self).__init__()

        self.cfgs = get_micronet_config(mode)

        activation_cfg = {}
        if mode == 'M0':
            input_channel = 4
            stem_groups = 2, 2
            out_ch = 384
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M1':
            input_channel = 6
            stem_groups = 3, 2
            out_ch = 576
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M2':
            input_channel = 8
            stem_groups = 4, 2
            out_ch = 768
            activation_cfg['init_a'] = 1.0, 1.0
            activation_cfg['init_b'] = 0.0, 0.0
        elif mode == 'M3':
            input_channel = 12
            stem_groups = 4, 3
            out_ch = 432
            activation_cfg['init_a'] = 1.0, 0.5
            activation_cfg['init_b'] = 0.0, 0.5
        else:
            raise NotImplementedError("mode[" + mode +
                                      "_model] is not implemented!")

        layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]

        for idx, val in enumerate(self.cfgs):
            s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val

            t1 = (c1, c2)
            gs1 = (g1, g2)
            gs2 = (c3, g3, g4)
            activation_cfg['dy'] = [y1, y2, y3]
            activation_cfg['ratio'] = r

            output_channel = c
            layers.append(
                DYMicroBlock(
                    input_channel,
                    output_channel,
                    kernel_size=ks,
                    stride=s,
                    ch_exp=t1,
                    ch_per_group=gs1,
                    groups_1x1=gs2,
                    depthsep=True,
                    shuffle=True,
                    activation_cfg=activation_cfg, ))
            input_channel = output_channel
            for i in range(1, n):
                layers.append(
                    DYMicroBlock(
                        input_channel,
                        output_channel,
                        kernel_size=ks,
                        stride=1,
                        ch_exp=t1,
                        ch_per_group=gs1,
                        groups_1x1=gs2,
                        depthsep=True,
                        shuffle=True,
                        activation_cfg=activation_cfg, ))
                input_channel = output_channel

        self.features = nn.Sequential(*layers)

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)

        self.out_channels = make_divisible(out_ch)

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        return x
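As a rough smoke test of the new backbone (a sketch only: it assumes ppocr is importable and a typical 32-pixel-high recognition input, neither of which the commit specifies):

# Smoke-test sketch for the MicroNet backbone; assumes ppocr is on the
# Python path and a made-up 32x320 dummy input in NCHW layout.
import paddle
from ppocr.modeling.backbones.rec_micronet import MicroNet

model = MicroNet(mode='M3')
x = paddle.rand([1, 3, 32, 320])
y = model(x)
# The stride-2 stem, the strided blocks and the final 2x2 max-pool
# progressively halve the spatial dims; out_channels is 432 for M3.
print(y.shape, model.out_channels)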
ppocr/modeling/backbones/vqa_layoutlm.py (new file, 0 → 100644, view file @ 41a1b292)

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from paddle import nn

from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification

__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']

pretrained_model_dict = {
    LayoutXLMModel: 'layoutxlm-base-uncased',
    LayoutLMModel: 'layoutlm-base-uncased'
}


class NLPBaseModel(nn.Layer):
    def __init__(self,
                 base_model_class,
                 model_class,
                 type='ser',
                 pretrained=True,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        if checkpoints is not None:
            self.model = model_class.from_pretrained(checkpoints)
        else:
            pretrained_model_name = pretrained_model_dict[base_model_class]
            if pretrained:
                base_model = base_model_class.from_pretrained(
                    pretrained_model_name)
            else:
                base_model = base_model_class(
                    **base_model_class.pretrained_init_configuration[
                        pretrained_model_name])
            if type == 'ser':
                self.model = model_class(
                    base_model, num_classes=kwargs['num_classes'], dropout=None)
            else:
                self.model = model_class(base_model, dropout=None)
        self.out_channels = 1


class LayoutXLMForSer(NLPBaseModel):
    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        return x[0]


class LayoutLMForSer(NLPBaseModel):
    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            output_hidden_states=False)
        return x


class LayoutXLMForRe(NLPBaseModel):
    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel, LayoutXLMForRelationExtraction, 're', pretrained,
            checkpoints)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return x
ppocr/optimizer/__init__.py (view file @ 41a1b292)
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
+        reg_name = reg_config.pop('name')
+        if not hasattr(regularizer, reg_name):
+            reg_name += 'Decay'
         reg = getattr(regularizer, reg_name)(**reg_config)()
     else:
         reg = None
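With this change a config may name the regularizer either 'L2' or 'L2Decay': the 'Decay' suffix is appended only when the bare name is not found on the regularizer module. A standalone sketch of that lookup (FakeRegularizerModule is a stand-in for illustration, not ppocr code):

# Standalone sketch of the name-resolution pattern introduced above:
# try the configured name first, fall back to name + 'Decay'.
class FakeRegularizerModule(object):
    class L2Decay(object):
        def __init__(self, factor=0.0):
            self.factor = factor


regularizer = FakeRegularizerModule()

for configured in ('L2', 'L2Decay'):
    reg_name = configured
    if not hasattr(regularizer, reg_name):
        reg_name += 'Decay'
    cls = getattr(regularizer, reg_name)
    print(configured, '->', cls.__name__)  # both resolve to L2Decay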
ppocr/optimizer/learning_rate.py (view file @ 41a1b292)
@@ -18,7 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals

 from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay


 class Linear(object):
@@ -226,3 +226,53 @@ class CyclicalCosine(object):
                 end_lr=self.learning_rate,
                 last_epoch=self.last_epoch)
         return learning_rate
+
+
+class OneCycle(object):
+    """
+    One Cycle learning rate decay
+    Args:
+        max_lr(float): Upper learning rate boundaries
+        epochs(int): total training epochs
+        step_each_epoch(int): steps each epoch
+        anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: "cos" for cosine annealing, "linear" for linear annealing.
+            Default: 'cos'
+        three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to 'final_div_factor'
+            instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by 'pct_start').
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs,
+                 step_each_epoch,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(OneCycle, self).__init__()
+        self.max_lr = max_lr
+        self.epochs = epochs
+        self.steps_per_epoch = step_each_epoch
+        self.anneal_strategy = anneal_strategy
+        self.three_phase = three_phase
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = OneCycleDecay(
+            max_lr=self.max_lr,
+            epochs=self.epochs,
+            steps_per_epoch=self.steps_per_epoch,
+            anneal_strategy=self.anneal_strategy,
+            three_phase=self.three_phase,
+            last_epoch=self.last_epoch)
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.max_lr,
+                last_epoch=self.last_epoch)
+        return learning_rate
\ No newline at end of file
ppocr/optimizer/lr_scheduler.py (view file @ 41a1b292)
@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
         lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
             (1 + math.cos(math.pi * reletive_epoch / self.cycle))
         return lr
+
+
+class OneCycleDecay(LRScheduler):
+    """
+    One Cycle learning rate decay
+    A learning rate which can be referred in https://arxiv.org/abs/1708.07120
+    Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25.,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 last_epoch=-1,
+                 verbose=False):
+
+        # Validate total_steps
+        if epochs <= 0 or not isinstance(epochs, int):
+            raise ValueError(
+                "Expected positive integer epochs, but got {}".format(epochs))
+        if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+            raise ValueError(
+                "Expected positive integer steps_per_epoch, but got {}".format(
+                    steps_per_epoch))
+        self.total_steps = epochs * steps_per_epoch
+
+        self.max_lr = max_lr
+        self.initial_lr = self.max_lr / div_factor
+        self.min_lr = self.initial_lr / final_div_factor
+
+        if three_phase:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': float(2 * pct_start * self.total_steps) - 2,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.initial_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+        else:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+
+        # Validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError(
+                "Expected float between 0 and 1 pct_start, but got {}".format(
+                    pct_start))
+
+        # Validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError(
+                "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
+                format(anneal_strategy))
+        elif anneal_strategy == 'cos':
+            self.anneal_func = self._annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = self._annealing_linear
+
+        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        computed_lr = 0.0
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError(
+                "Tried to step {} times. The specified number of total steps is {}"
+                .format(step_num + 1, self.total_steps))
+        start_step = 0
+        for i, phase in enumerate(self._schedule_phases):
+            end_step = phase['end_step']
+            if step_num <= end_step or i == len(self._schedule_phases) - 1:
+                pct = (step_num - start_step) / (end_step - start_step)
+                computed_lr = self.anneal_func(phase['start_lr'],
+                                               phase['end_lr'], pct)
+                break
+            start_step = phase['end_step']
+
+        return computed_lr
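get_lr walks the phase list and anneals within the phase that contains the current step. A standalone sketch of the default two-phase shape with toy numbers (the values here are illustrative, not from the commit):

# Standalone sketch of the two-phase one-cycle shape computed by get_lr:
# warm up from max_lr/div_factor to max_lr over pct_start of the steps,
# then cosine-anneal down to initial_lr/final_div_factor.
import math


def annealing_cos(start, end, pct):
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2.0 * cos_out


max_lr, div_factor, final_div_factor, pct_start = 0.1, 25.0, 1e4, 0.3
total_steps = 100
initial_lr = max_lr / div_factor
min_lr = initial_lr / final_div_factor
phases = [
    {'end_step': pct_start * total_steps - 1, 'start_lr': initial_lr, 'end_lr': max_lr},
    {'end_step': total_steps - 1, 'start_lr': max_lr, 'end_lr': min_lr},
]


def one_cycle_lr(step):
    start_step = 0
    for i, phase in enumerate(phases):
        if step <= phase['end_step'] or i == len(phases) - 1:
            pct = (step - start_step) / (phase['end_step'] - start_step)
            return annealing_cos(phase['start_lr'], phase['end_lr'], pct)
        start_step = phase['end_step']


for step in (0, 29, 99):
    print(step, one_cycle_lr(step))
# step 0 -> 0.004 (initial_lr), step 29 -> ~0.1 (peak), step 99 -> ~4e-07 (min_lr)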
ppocr/optimizer/optimizer.py (view file @ 41a1b292)
@@ -158,3 +158,38 @@ class Adadelta(object):
             name=self.name,
             parameters=parameters)
         return opt
+
+
+class AdamW(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-08,
+                 weight_decay=0.01,
+                 grad_clip=None,
+                 name=None,
+                 lazy_mode=False,
+                 **kwargs):
+        self.learning_rate = learning_rate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.weight_decay = 0.01 if weight_decay is None else weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+        self.lazy_mode = lazy_mode
+
+    def __call__(self, parameters):
+        opt = optim.AdamW(
+            learning_rate=self.learning_rate,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            epsilon=self.epsilon,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            lazy_mode=self.lazy_mode,
+            parameters=parameters)
+        return opt
ppocr/optimizer/regularizer.py (view file @ 41a1b292)
@@ -29,24 +29,23 @@ class L1Decay(object):
     def __init__(self, factor=0.0):
         super(L1Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.coeff = factor

     def __call__(self):
-        reg = paddle.regularizer.L1Decay(self.regularization_coeff)
+        reg = paddle.regularizer.L1Decay(self.coeff)
         return reg


 class L2Decay(object):
     """
-    L2 Weight Decay Regularization, which encourages the weights to be sparse.
+    L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
     Args:
         factor(float): regularization coeff. Default:0.0.
     """

     def __init__(self, factor=0.0):
         super(L2Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.coeff = float(factor)

     def __call__(self):
-        reg = paddle.regularizer.L2Decay(self.regularization_coeff)
-        return reg
+        return self.coeff
\ No newline at end of file
ppocr/postprocess/__init__.py (view file @ 41a1b292)
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
     TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess


 def build_post_process(config, global_config=None):
@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
         'DistillationCTCLabelDecode', 'TableLabelDecode',
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
-        'SEEDLabelDecode'
+        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+        'VQAReTokenLayoutLMPostProcess'
     ]

     if config['name'] == 'PSEPostProcess':
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py (new file, 0 → 100644, view file @ 41a1b292)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class VQAReTokenLayoutLMPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, **kwargs):
        super(VQAReTokenLayoutLMPostProcess, self).__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        if label is not None:
            return self._metric(preds, label)
        else:
            return self._infer(preds, *args, **kwargs)

    def _metric(self, preds, label):
        return preds['pred_relations'], label[6], label[5]

    def _infer(self, preds, *args, **kwargs):
        ser_results = kwargs['ser_results']
        entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
        pred_relations = preds['pred_relations']

        # merge relations and ocr info
        results = []
        for pred_relation, ser_result, entity_idx_dict in zip(
                pred_relations, ser_results, entity_idx_dict_batch):
            result = []
            used_tail_id = []
            for relation in pred_relation:
                if relation['tail_id'] in used_tail_id:
                    continue
                used_tail_id.append(relation['tail_id'])
                ocr_info_head = ser_result[entity_idx_dict[relation[
                    'head_id']]]
                ocr_info_tail = ser_result[entity_idx_dict[relation[
                    'tail_id']]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py (new file, 0 → 100644, view file @ 41a1b292)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps


class VQASerTokenLayoutLMPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, class_path, **kwargs):
        super(VQASerTokenLayoutLMPostProcess, self).__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

        self.id2label_map_for_show = dict()
        for key in self.label2id_map_for_draw:
            val = self.label2id_map_for_draw[key]
            if key == "O":
                self.id2label_map_for_show[val] = key
            if key.startswith("B-") or key.startswith("I-"):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[1])
        else:
            return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        pred_idxs = preds.argmax(axis=2)
        decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
        label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]

        for i in range(pred_idxs.shape[0]):
            for j in range(pred_idxs.shape[1]):
                if label[i, j] != -100:
                    label_decode_out_list[i].append(self.id2label_map[label[
                        i, j]])
                    decode_out_list[i].append(self.id2label_map[pred_idxs[
                        i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
        results = []
        for pred, attention_mask, segment_offset_id, ocr_info in zip(
                preds, attention_masks, segment_offset_ids, ocr_infos):
            pred = np.argmax(pred, axis=1)
            pred = [self.id2label_map[idx] for idx in pred]

            for idx in range(len(segment_offset_id)):
                if idx == 0:
                    start_id = 0
                else:
                    start_id = segment_offset_id[idx - 1]

                end_id = segment_offset_id[idx]

                curr_pred = pred[start_id:end_id]
                curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]

                if len(curr_pred) <= 0:
                    pred_id = 0
                else:
                    counts = np.bincount(curr_pred)
                    pred_id = np.argmax(counts)
                ocr_info[idx]["pred_id"] = int(pred_id)
                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(
                    pred_id)]
            results.append(ocr_info)
        return results
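_infer collapses token-level predictions into one label per text segment by majority vote over the segment's tokens, using np.bincount plus argmax. A tiny standalone sketch with made-up label ids:

# Majority vote over a segment's token label ids, as done in _infer above.
import numpy as np

curr_pred = [3, 3, 1, 3, 0]      # label ids for the tokens of one segment
counts = np.bincount(curr_pred)  # occurrences per label id: [1 1 0 3]
pred_id = np.argmax(counts)      # most frequent id wins
print(counts, int(pred_id))      # [1 1 0 3] 3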
ppocr/utils/save_load.py (view file @ 41a1b292)
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))


-def load_model(config, model, optimizer=None):
+def load_model(config, model, optimizer=None, model_type='det'):
     """
     load model from checkpoint or pretrained_model
     """
@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
+    if model_type == 'vqa':
+        checkpoints = config['Architecture']['Backbone']['checkpoints']
+        # load vqa method metric
+        if checkpoints:
+            if os.path.exists(os.path.join(checkpoints, 'metric.states')):
+                with open(os.path.join(checkpoints, 'metric.states'),
+                          'rb') as f:
+                    states_dict = pickle.load(f) if six.PY2 else pickle.load(
+                        f, encoding='latin1')
+                best_model_dict = states_dict.get('best_model_dict', {})
+                if 'epoch' in states_dict:
+                    best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+            logger.info("resume from {}".format(checkpoints))
+
+            if optimizer is not None:
+                if checkpoints[-1] in ['/', '\\']:
+                    checkpoints = checkpoints[:-1]
+                if os.path.exists(checkpoints + '.pdopt'):
+                    optim_dict = paddle.load(checkpoints + '.pdopt')
+                    optimizer.set_state_dict(optim_dict)
+                else:
+                    logger.warning(
+                        "{}.pdopt is not exists, params of optimizer is not loaded".
+                        format(checkpoints))
+        return best_model_dict
+
     if checkpoints:
         if checkpoints.endswith('.pdparams'):
             checkpoints = checkpoints.replace('.pdparams', '')
@@ -111,13 +138,16 @@ def load_pretrained_params(model, path):
     params = paddle.load(path + '.pdparams')
     state_dict = model.state_dict()
     new_state_dict = {}
-    for k1, k2 in zip(state_dict.keys(), params.keys()):
-        if list(state_dict[k1].shape) == list(params[k2].shape):
-            new_state_dict[k1] = params[k2]
+    for k1 in params.keys():
+        if k1 not in state_dict.keys():
+            logger.warning("The pretrained params {} not in model".format(k1))
         else:
-            logger.warning(
-                "The shape of model params {} {} not matched with loaded params {} {} !".
-                format(k1, state_dict[k1].shape, k2, params[k2].shape))
+            if list(state_dict[k1].shape) == list(params[k1].shape):
+                new_state_dict[k1] = params[k1]
+            else:
+                logger.warning(
+                    "The shape of model params {} {} not matched with loaded params {} {} !".
+                    format(k1, state_dict[k1].shape, k1, params[k1].shape))
     model.set_state_dict(new_state_dict)
     logger.info("load pretrain successful from {}".format(path))
     return model
@@ -127,6 +157,7 @@ def save_model(model,
                optimizer,
                model_path,
                logger,
+               config,
                is_best=False,
                prefix='ppocr',
                **kwargs):
@@ -135,13 +166,20 @@ def save_model(model,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(model.state_dict(), model_prefix + '.pdparams')
     paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+    if config['Architecture']["model_type"] != 'vqa':
+        paddle.save(model.state_dict(), model_prefix + '.pdparams')
+        metric_prefix = model_prefix
+    else:
+        if config['Global']['distributed']:
+            model._layers.backbone.model.save_pretrained(model_prefix)
+        else:
+            model.backbone.model.save_pretrained(model_prefix)
+        metric_prefix = os.path.join(model_prefix, 'metric')
     # save metric and config
-    with open(model_prefix + '.states', 'wb') as f:
-        pickle.dump(kwargs, f, protocol=2)
     if is_best:
+        with open(metric_prefix + '.states', 'wb') as f:
+            pickle.dump(kwargs, f, protocol=2)
         logger.info('save best model is to {}'.format(model_prefix))
     else:
         logger.info("save model in {}".format(model_prefix))
ppocr/utils/utility.py (view file @ 41a1b292)
@@ -16,6 +16,9 @@ import logging
 import os
 import imghdr
 import cv2
+import random
+import numpy as np
+import paddle


 def print_dict(d, logger, delimiter=0):
@@ -77,4 +80,28 @@ def check_and_read_gif(img_path):
             frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
         imgvalue = frame[:, :, ::-1]
         return imgvalue, True
-    return None, False
\ No newline at end of file
+    return None, False
+
+
+def load_vqa_bio_label_maps(label_map_path):
+    with open(label_map_path, "r", encoding='utf-8') as fin:
+        lines = fin.readlines()
+    lines = [line.strip() for line in lines]
+    if "O" not in lines:
+        lines.insert(0, "O")
+    labels = []
+    for line in lines:
+        if line == "O":
+            labels.append("O")
+        else:
+            labels.append("B-" + line)
+            labels.append("I-" + line)
+    label2id_map = {label: idx for idx, label in enumerate(labels)}
+    id2label_map = {idx: label for idx, label in enumerate(labels)}
+    return label2id_map, id2label_map
+
+
+def set_seed(seed=1024):
+    random.seed(seed)
+    np.random.seed(seed)
+    paddle.seed(seed)
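load_vqa_bio_label_maps expands each class name into B-/I- tags, with 'O' forced to id 0. A standalone sketch of the same expansion on an in-memory list (file I/O omitted for the example):

# Standalone sketch of the BIO expansion performed by
# load_vqa_bio_label_maps, using an in-memory class list.
lines = ["QUESTION", "ANSWER"]
if "O" not in lines:
    lines.insert(0, "O")
labels = []
for line in lines:
    if line == "O":
        labels.append("O")
    else:
        labels.append("B-" + line)
        labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
print(label2id_map)
# {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4}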