Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
d9c28128
Commit
d9c28128
authored
Oct 09, 2021
by
LDOUBLEV
Browse files
fix multi-inputs
parent
4e0fcd6e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
17 deletions
+11
-17
ppocr/modeling/backbones/kie_unet_sdmgr.py
ppocr/modeling/backbones/kie_unet_sdmgr.py
+4
-12
tools/infer_kie.py
tools/infer_kie.py
+4
-2
tools/program.py
tools/program.py
+3
-3
No files found.
ppocr/modeling/backbones/kie_unet_sdmgr.py
View file @
d9c28128
...
@@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer):
...
@@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer):
gt_bboxes
[
i
,
:
num
,
...],
dtype
=
'float32'
))
gt_bboxes
[
i
,
:
num
,
...],
dtype
=
'float32'
))
return
img
,
temp_relations
,
temp_texts
,
temp_gt_bboxes
return
img
,
temp_relations
,
temp_texts
,
temp_gt_bboxes
def
forward
(
self
,
inputs
):
def
forward
(
self
,
images
,
inputs
):
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
0
],
inputs
[
img
=
images
1
],
inputs
[
2
],
inputs
[
3
],
inputs
[
5
],
inputs
[
-
1
]
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
4
],
inputs
[
-
1
]
img
,
relations
,
texts
,
gt_bboxes
=
self
.
pre_process
(
img
,
relations
,
texts
,
gt_bboxes
=
self
.
pre_process
(
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
)
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
)
# for i in range(4):
# img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8')
# img_t = img_t.copy()
# gt_bboxes_t = gt_bboxes[i].cpu().numpy()
# box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2))
# cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1)
# cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t)
# # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0)
# exit()
x
=
self
.
img_feat
(
img
)
x
=
self
.
img_feat
(
img
)
boxes
,
rois_num
=
self
.
bbox2roi
(
gt_bboxes
)
boxes
,
rois_num
=
self
.
bbox2roi
(
gt_bboxes
)
feats
=
paddle
.
fluid
.
layers
.
roi_align
(
feats
=
paddle
.
fluid
.
layers
.
roi_align
(
...
...
tools/infer_kie.py
View file @
d9c28128
...
@@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
...
@@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img
=
np
.
ones
((
h
,
w
*
3
,
3
),
dtype
=
np
.
uint8
)
*
255
vis_img
=
np
.
ones
((
h
,
w
*
3
,
3
),
dtype
=
np
.
uint8
)
*
255
vis_img
[:,
:
w
]
=
img
vis_img
[:,
:
w
]
=
img
vis_img
[:,
w
:]
=
pred_img
vis_img
[:,
w
:]
=
pred_img
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
if
not
os
.
path
.
exists
(
save_kie_path
):
if
not
os
.
path
.
exists
(
save_kie_path
):
os
.
makedirs
(
save_kie_path
)
os
.
makedirs
(
save_kie_path
)
save_path
=
os
.
path
.
join
(
save_kie_path
,
str
(
count
)
+
".png"
)
save_path
=
os
.
path
.
join
(
save_kie_path
,
str
(
count
)
+
".png"
)
...
@@ -128,7 +129,8 @@ def main():
...
@@ -128,7 +129,8 @@ def main():
batch_pred
[
i
]
=
paddle
.
to_tensor
(
batch_pred
[
i
]
=
paddle
.
to_tensor
(
np
.
expand_dims
(
np
.
expand_dims
(
batch
[
i
],
axis
=
0
))
batch
[
i
],
axis
=
0
))
node
,
edge
=
model
(
batch_pred
)
node
,
edge
=
model
(
batch
[
0
],
batch
[
1
:])
node
=
F
.
softmax
(
node
,
-
1
)
node
=
F
.
softmax
(
node
,
-
1
)
draw_kie_result
(
batch
,
node
,
idx_to_cls
,
index
)
draw_kie_result
(
batch
,
node
,
idx_to_cls
,
index
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
tools/program.py
View file @
d9c28128
...
@@ -197,7 +197,7 @@ def train(config,
...
@@ -197,7 +197,7 @@ def train(config,
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SDMGR"
]
try
:
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
...
@@ -230,7 +230,7 @@ def train(config,
...
@@ -230,7 +230,7 @@ def train(config,
if
model_type
==
'table'
or
extra_input
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
batch
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
avg_loss
=
loss
[
'loss'
]
avg_loss
.
backward
()
avg_loss
.
backward
()
...
@@ -379,7 +379,7 @@ def eval(model,
...
@@ -379,7 +379,7 @@ def eval(model,
if
model_type
==
'table'
or
extra_input
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
batch
)
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment