Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
c9e1077d
Commit
c9e1077d
authored
Aug 30, 2021
by
tink2123
Browse files
polish code
parent
59cc4efd
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
19 deletions
+95
-19
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+83
-4
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+7
-10
tools/program.py
tools/program.py
+5
-5
No files found.
ppocr/postprocess/rec_postprocess.py
View file @
c9e1077d
...
...
@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
unkonwn
=
"UNKNOWN"
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
+
[
self
.
unkonwn
]
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
...
...
@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds
=
preds
[
"rec_pred"
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
...
...
@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
return
idx
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
end_idx
=
self
.
get_beg_end_flag_idx
(
"eos"
)
return
[
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"sos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"eos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
%
beg_or_end
return
idx
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
[
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
)))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx
=
preds
[
"rec_pred"
]
if
isinstance
(
preds_idx
,
paddle
.
Tensor
):
preds_idx
=
preds_idx
.
numpy
()
if
"rec_pred_scores"
in
preds
:
preds_idx
=
preds
[
"rec_pred"
]
preds_prob
=
preds
[
"rec_pred_scores"
]
else
:
preds_idx
=
preds
[
"rec_pred"
].
argmax
(
axis
=
2
)
preds_prob
=
preds
[
"rec_pred"
].
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
ppocr/utils/save_load.py
View file @
c9e1077d
...
...
@@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
params
=
paddle
.
load
(
pm
)
state_dict
=
model
.
state_dict
()
new_state_dict
=
{}
# for k1, k2 in zip(state_dict.keys(), params.keys()):
for
k1
in
state_dict
.
keys
():
if
k1
not
in
params
:
continue
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k1
].
shape
):
new_state_dict
[
k1
]
=
params
[
k1
]
else
:
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k1
}
{
params
[
k1
].
shape
}
!"
)
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
model
.
set_state_dict
(
new_state_dict
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
return
{}
...
...
tools/program.py
View file @
c9e1077d
...
...
@@ -211,11 +211,10 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
# if use_srn or model_type == 'table' or algorithm == "ASTER":
# preds = model(images, data=batch[1:])
# else:
# preds = model(images)
preds
=
model
(
images
,
data
=
batch
[
1
:])
if
use_srn
or
model_type
==
'table'
or
model_type
==
"seed"
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
state_dict
=
model
.
state_dict
()
# for key in state_dict:
# print(key)
...
...
@@ -415,6 +414,7 @@ def preprocess(is_train=False):
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
print
(
"log has save in {}/train.log"
.
format
(
save_model_dir
))
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment