Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
0c177c65
Commit
0c177c65
authored
Dec 16, 2020
by
weishengyu
Browse files
Merge
https://github.com/PaddlePaddle/PaddleOCR
into dygraph
parents
78d3349e
fbf66516
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
15 additions
and
11 deletions
+15
-11
doc/imgs_results/det_res_00018069.jpg
doc/imgs_results/det_res_00018069.jpg
+0
-0
doc/imgs_results/det_res_2.jpg
doc/imgs_results/det_res_2.jpg
+0
-0
doc/imgs_results/det_res_22.jpg
doc/imgs_results/det_res_22.jpg
+0
-0
ppocr/losses/det_db_loss.py
ppocr/losses/det_db_loss.py
+4
-3
ppocr/modeling/heads/det_db_head.py
ppocr/modeling/heads/det_db_head.py
+2
-2
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+4
-2
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+0
-1
tools/infer/predict_det.py
tools/infer/predict_det.py
+5
-3
No files found.
doc/imgs_results/det_res_00018069.jpg
0 → 100644
View file @
0c177c65
85.9 KB
doc/imgs_results/det_res_2.jpg
deleted
100644 → 0
View file @
78d3349e
77.3 KB
doc/imgs_results/det_res_22.jpg
deleted
100644 → 0
View file @
78d3349e
76.3 KB
ppocr/losses/det_db_loss.py
View file @
0c177c65
...
...
@@ -47,11 +47,12 @@ class DBLoss(nn.Layer):
negative_ratio
=
ohem_ratio
)
def
forward
(
self
,
predicts
,
labels
):
predict_maps
=
predicts
[
'maps'
]
label_threshold_map
,
label_threshold_mask
,
label_shrink_map
,
label_shrink_mask
=
labels
[
1
:]
shrink_maps
=
predicts
[:,
0
,
:,
:]
threshold_maps
=
predicts
[:,
1
,
:,
:]
binary_maps
=
predicts
[:,
2
,
:,
:]
shrink_maps
=
predict
_map
s
[:,
0
,
:,
:]
threshold_maps
=
predict
_map
s
[:,
1
,
:,
:]
binary_maps
=
predict
_map
s
[:,
2
,
:,
:]
loss_shrink_maps
=
self
.
bce_loss
(
shrink_maps
,
label_shrink_map
,
label_shrink_mask
)
...
...
ppocr/modeling/heads/det_db_head.py
View file @
0c177c65
...
...
@@ -120,9 +120,9 @@ class DBHead(nn.Layer):
def
forward
(
self
,
x
):
shrink_maps
=
self
.
binarize
(
x
)
if
not
self
.
training
:
return
shrink_maps
return
{
'maps'
:
shrink_maps
}
threshold_maps
=
self
.
thresh
(
x
)
binary_maps
=
self
.
step_function
(
shrink_maps
,
threshold_maps
)
y
=
paddle
.
concat
([
shrink_maps
,
threshold_maps
,
binary_maps
],
axis
=
1
)
return
y
return
{
'maps'
:
y
}
ppocr/postprocess/db_postprocess.py
View file @
0c177c65
...
...
@@ -40,7 +40,8 @@ class DBPostProcess(object):
self
.
max_candidates
=
max_candidates
self
.
unclip_ratio
=
unclip_ratio
self
.
min_size
=
3
self
.
dilation_kernel
=
None
if
not
use_dilation
else
np
.
array
([[
1
,
1
],
[
1
,
1
]])
self
.
dilation_kernel
=
None
if
not
use_dilation
else
np
.
array
(
[[
1
,
1
],
[
1
,
1
]])
def
boxes_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
...
...
@@ -132,7 +133,8 @@ class DBPostProcess(object):
cv2
.
fillPoly
(
mask
,
box
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
__call__
(
self
,
pred
,
shape_list
):
def
__call__
(
self
,
outs_dict
,
shape_list
):
pred
=
outs_dict
[
'maps'
]
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
...
...
ppocr/utils/save_load.py
View file @
0c177c65
...
...
@@ -102,7 +102,6 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict
=
states_dict
.
get
(
'best_model_dict'
,
{})
if
'epoch'
in
states_dict
:
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
best_model_dict
[
'start_epoch'
]
=
best_model_dict
[
'best_epoch'
]
+
1
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
elif
pretrained_model
:
...
...
tools/infer/predict_det.py
View file @
0c177c65
...
...
@@ -65,12 +65,12 @@ class TextDetector(object):
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
True
elif
self
.
det_algorithm
==
"EAST"
:
postprocess_params
[
'name'
]
=
'EASTPostProcess'
postprocess_params
[
'name'
]
=
'EASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_east_score_thresh
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
elif
self
.
det_algorithm
==
"SAST"
:
postprocess_params
[
'name'
]
=
'SASTPostProcess'
postprocess_params
[
'name'
]
=
'SASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_sast_score_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_sast_nms_thresh
self
.
det_sast_polygon
=
args
.
det_sast_polygon
...
...
@@ -177,8 +177,10 @@ class TextDetector(object):
preds
[
'f_score'
]
=
outputs
[
1
]
preds
[
'f_tco'
]
=
outputs
[
2
]
preds
[
'f_tvo'
]
=
outputs
[
3
]
elif
self
.
det_algorithm
==
'DB'
:
preds
[
'maps'
]
=
outputs
[
0
]
else
:
preds
=
outputs
[
0
]
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment