Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
e93735a2
Unverified
Commit
e93735a2
authored
Jun 15, 2021
by
MissPenguin
Committed by
GitHub
Jun 15, 2021
Browse files
Merge pull request #3083 from WenmuZhou/table1
[DO NOT MERGE]Table
parents
6127aad9
b2260182
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
483 additions
and
6 deletions
+483
-6
test/table/tablepyxl/__init__.py
test/table/tablepyxl/__init__.py
+13
-0
test/table/tablepyxl/style.py
test/table/tablepyxl/style.py
+283
-0
test/table/tablepyxl/tablepyxl.py
test/table/tablepyxl/tablepyxl.py
+118
-0
test/utility.py
test/utility.py
+54
-0
tools/infer/predict_det.py
tools/infer/predict_det.py
+1
-1
tools/infer/predict_system.py
tools/infer/predict_system.py
+7
-3
tools/infer/utility.py
tools/infer/utility.py
+7
-2
No files found.
test/table/tablepyxl/__init__.py
0 → 100644
View file @
e93735a2
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
test/table/tablepyxl/style.py
0 → 100644
View file @
e93735a2
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.
from
openpyxl.cell
import
cell
from
openpyxl.styles
import
Font
,
Alignment
,
PatternFill
,
NamedStyle
,
Border
,
Side
,
Color
from
openpyxl.styles.fills
import
FILL_SOLID
from
openpyxl.styles.numbers
import
FORMAT_CURRENCY_USD_SIMPLE
,
FORMAT_PERCENTAGE
from
openpyxl.styles.colors
import
BLACK
# Number format string returned by TableCell.get_number_format for cells
# whose class attribute contains TYPE_DATE.
FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
def colormap(color):
    """
    Translate a known color name into its openpyxl constant.

    Names without a known mapping (including None) are returned unchanged.
    """
    named_colors = {'black': BLACK}
    try:
        return named_colors[color]
    except KeyError:
        return color
def style_string_to_dict(style):
    """
    Convert css style string to a python dictionary

    Declarations are separated by ";" and each declaration is split into a
    property name and a value at its first ":". Fragments without a ":" are
    ignored. Names and values are stripped of surrounding whitespace.

    :param style: the content of an html ``style`` attribute
    :return: dict mapping property names to their values
    """
    styles = {}
    for declaration in style.split(";"):
        if ":" not in declaration:
            continue
        # Split on the first colon only: values such as url(http://...) may
        # themselves contain colons and must stay in one piece.
        name, _, value = declaration.partition(":")
        styles[name.strip()] = value.strip()
    return styles
def get_side(style, name):
    """
    Collect the keyword arguments describing one border side.

    Looks up ``border-<name>-style`` and ``border-<name>-color`` in the given
    style mapping; the color is passed through colormap.
    """
    border_style = style.get('border-{}-style'.format(name))
    border_color = colormap(style.get('border-{}-color'.format(name)))
    return {'border_style': border_style, 'color': border_color}
# Module-level cache filled by style_dict_to_named_style: maps a string
# derived from the style and number format to the NamedStyle built for it.
known_styles = {}
def style_dict_to_named_style(style_dict, number_format=None):
    """
    Change css style (stored in a python dictionary) to openpyxl NamedStyle

    Results are cached in the module-level ``known_styles`` dict, so equal
    resolved styles share one NamedStyle.

    :param style_dict: a StyleDict (``get``, ``get_color`` and ``_keys`` are used)
    :param number_format: number format string handed to the NamedStyle, or None
    :return: the (possibly cached) NamedStyle
    """
    # Build the cache key from the fully cascaded styles. Every lookup below
    # goes through the cascade, so the resolved items (plus the number format)
    # are exactly what determines the resulting NamedStyle. Keying only on the
    # element's own dict and its direct parent would let two elements that
    # differ solely through a more distant ancestor share the wrong style.
    resolved_styles = sorted(
        ((k, style_dict.get(k)) for k in style_dict._keys()), key=repr)
    style_and_format_string = str((resolved_styles, number_format))

    if style_and_format_string not in known_styles:
        # Font
        font = Font(bold=style_dict.get('font-weight') == 'bold',
                    color=style_dict.get_color('color', None),
                    size=style_dict.get('font-size'))

        # Alignment
        alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
                              vertical=style_dict.get('vertical-align'),
                              wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')

        # Fill
        bg_color = style_dict.get_color('background-color')
        fg_color = style_dict.get_color('foreground-color', Color())
        fill_type = style_dict.get('fill-type')
        if bg_color and bg_color != 'transparent':
            fill = PatternFill(fill_type=fill_type or FILL_SOLID,
                               start_color=bg_color,
                               end_color=fg_color)
        else:
            fill = PatternFill()

        # Border
        border = Border(left=Side(**get_side(style_dict, 'left')),
                        right=Side(**get_side(style_dict, 'right')),
                        top=Side(**get_side(style_dict, 'top')),
                        bottom=Side(**get_side(style_dict, 'bottom')),
                        diagonal=Side(**get_side(style_dict, 'diagonal')),
                        diagonal_direction=None,
                        outline=Side(**get_side(style_dict, 'outline')),
                        vertical=None,
                        horizontal=None)

        # Names are numbered in creation order so each cached style is unique.
        name = 'Style {}'.format(len(known_styles) + 1)

        pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
                                number_format=number_format)

        known_styles[style_and_format_string] = pyxl_style

    return known_styles[style_and_format_string]
class StyleDict(dict):
    """
    It's like a dictionary, but it looks for items in the parent dictionary

    The optional ``parent`` keyword argument is another StyleDict (or None).
    Item access, ``get`` and ``get_color`` fall back to the parent chain when
    a key is not stored on this instance.
    """
    def __init__(self, *args, **kwargs):
        # 'parent' is consumed here so it is not stored as a style entry.
        self.parent = kwargs.pop('parent', None)
        super(StyleDict, self).__init__(*args, **kwargs)

    def __getitem__(self, item):
        if item in self:
            return super(StyleDict, self).__getitem__(item)
        # Compare with None rather than testing truthiness: a parent without
        # styles of its own is an empty (falsy) dict, yet its ancestors may
        # still define the item.
        if self.parent is not None:
            return self.parent[item]
        raise KeyError('{} not found'.format(item))

    def __hash__(self):
        # Hash over the cascaded (key, value) pairs.
        return hash(tuple([(k, self.get(k)) for k in self._keys()]))

    # Yielding the keys avoids creating unnecessary data structures
    # and happily works with both python2 and python3 where the
    # .keys() method is a dictionary_view in python3 and a list in python2.
    def _keys(self):
        """
        Yield every key visible through the cascade exactly once, own keys first.
        """
        yielded = set()
        for k in self.keys():
            yielded.add(k)
            yield k
        # Same None check as __getitem__: do not stop at an empty parent.
        if self.parent is not None:
            for k in self.parent._keys():
                if k not in yielded:
                    yielded.add(k)
                    yield k

    def get(self, k, d=None):
        """
        Return the cascaded value for k, or d when no level defines it.
        """
        try:
            return self[k]
        except KeyError:
            return d

    def get_color(self, k, d=None):
        """
        Strip leading # off colors if necessary
        """
        color = self.get(k, d)
        if hasattr(color, 'startswith') and color.startswith('#'):
            color = color[1:]
            if len(color) == 3:
                # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
                color = ''.join(2 * c for c in color)
        return color
class Element(object):
    """
    Our base class for representing an html element along with a cascading style.
    The element is created along with a parent so that the StyleDict that we store
    can point to the parent's StyleDict.
    """
    def __init__(self, element, parent=None):
        """
        :param element: html element; its ``style`` attribute is read via ``element.get``
        :param parent: the enclosing Element, or None for the root
        """
        self.element = element
        self.number_format = None

        parent_style = parent.style_dict if parent is not None else None
        self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
        # Filled lazily by style().
        self._style_cache = None

    def style(self):
        """
        Turn the css styles for this element into an openpyxl NamedStyle.
        """
        if self._style_cache is None:
            self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
        return self._style_cache

    def get_dimension(self, dimension_key):
        """
        Extracts the dimension from the style dict of the Element and returns it as a float.

        A trailing two-letter unit (px, em, pt, in, cm) is dropped before
        conversion. A missing or empty value is returned unchanged, and a value
        that is not numeric (for example ``none``, ``auto`` or a percentage)
        yields None instead of raising.
        """
        dimension = self.style_dict.get(dimension_key)
        if not dimension:
            return dimension
        if dimension[-2:] in ('px', 'em', 'pt', 'in', 'cm'):
            dimension = dimension[:-2]
        try:
            return float(dimension)
        except ValueError:
            # CSS keywords and unsupported units carry no usable number.
            return None
class Table(Element):
    """
    Root of the fixed element tree used for html tables: a Table owns an
    optional TableHead and a TableBody, which in turn own rows and cells.
    A concrete tree is used on purpose instead of a generic element with an
    arbitrary number of children.
    """
    def __init__(self, table):
        """
        takes an html table object (from lxml)
        """
        super(Table, self).__init__(table)

        thead = table.find('thead')
        if thead is None:
            self.head = None
        else:
            self.head = TableHead(thead, parent=self)

        # Without a <tbody> the rows are looked up on the table element itself.
        tbody = table.find('tbody')
        body_source = table if tbody is None else tbody
        self.body = TableBody(body_source, parent=self)
class TableHead(Element):
    """
    This class maps to the `<thead>` element of the html table.
    """
    def __init__(self, head, parent=None):
        super(TableHead, self).__init__(head, parent=parent)
        # One TableRow per direct <tr> child, each cascading from this head.
        head_rows = []
        for tr in head.findall('tr'):
            head_rows.append(TableRow(tr, parent=self))
        self.rows = head_rows
class TableBody(Element):
    """
    This class maps to the `<tbody>` element of the html table.
    """
    def __init__(self, body, parent=None):
        super(TableBody, self).__init__(body, parent=parent)
        # One TableRow per direct <tr> child, each cascading from this body.
        body_rows = []
        for tr in body.findall('tr'):
            body_rows.append(TableRow(tr, parent=self))
        self.rows = body_rows
class TableRow(Element):
    """
    This class maps to the `<tr>` element of the html table.

    ``cells`` holds one TableCell per direct ``<th>``/``<td>`` child, in the
    order the cells appear in the row.
    """
    def __init__(self, tr, parent=None):
        super(TableRow, self).__init__(tr, parent=parent)
        # Walk the children once so header and data cells keep their document
        # order; collecting all <th> first and all <td> afterwards would move
        # a header cell in front of data cells that precede it in the markup.
        self.cells = [TableCell(child, parent=self)
                      for child in tr.iterchildren('th', 'td')]
def element_to_string(el):
    """
    Return the text content of el and its descendants, stripped of
    surrounding whitespace.
    """
    return _element_to_string(el).strip()


def _element_to_string(el):
    """
    Recursively join an element's own text, the text of each child (each
    preceded by a newline) and the element's tail (preceded by a newline).
    Text and tail are stripped individually; missing ones count as ''.
    """
    child_chunks = ['\n' + _element_to_string(child) for child in el.iterchildren()]
    text = (el.text or '').strip()
    tail = (el.tail or '').strip()
    return ''.join([text] + child_chunks + ['\n', tail])
class TableCell(Element):
    """
    Wraps one cell element of a row (TableRow builds these from both `<th>`
    and `<td>` children) together with its text value and number format.
    """
    CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
                  'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}

    def __init__(self, cell, parent=None):
        super(TableCell, self).__init__(cell, parent=parent)
        self.value = element_to_string(cell)
        self.number_format = self.get_number_format()

    def data_type(self):
        """
        Resolve the openpyxl cell data type from the element's css classes.
        Without a recognised class the cell is treated as a string.
        """
        declared = self.CELL_TYPES & set(self.element.get('class', '').split())
        if not declared:
            cell_type = 'TYPE_STRING'
        elif 'TYPE_FORMULA' in declared:
            # Make sure TYPE_FORMULA takes precedence over the other classes in the set.
            cell_type = 'TYPE_FORMULA'
        elif declared & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
            cell_type = 'TYPE_NUMERIC'
        else:
            cell_type = declared.pop()
        # Here `cell` is the module-level openpyxl import, not a worksheet cell.
        return getattr(cell, cell_type)

    def get_number_format(self):
        """
        Pick a number format from the css classes; numeric cells without an
        explicit class get an integer or decimal format depending on whether
        the value parses as an int. Returns None when nothing applies.
        """
        css_classes = self.element.get('class', '').split()
        class_formats = (
            ('TYPE_CURRENCY', FORMAT_CURRENCY_USD_SIMPLE),
            ('TYPE_INTEGER', '#,##0'),
            ('TYPE_PERCENTAGE', FORMAT_PERCENTAGE),
            ('TYPE_DATE', FORMAT_DATE_MMDDYYYY),
        )
        for class_name, number_format in class_formats:
            if class_name in css_classes:
                return number_format

        if self.data_type() != cell.TYPE_NUMERIC:
            return None
        try:
            int(self.value)
        except ValueError:
            return '#,##0.##'
        return '#,##0'

    def format(self, cell):
        """
        Apply this cell's NamedStyle and data type to the given worksheet cell.
        """
        cell.style = self.style()
        resolved_type = self.data_type()
        if resolved_type:
            cell.data_type = resolved_type
\ No newline at end of file
test/table/tablepyxl/tablepyxl.py
0 → 100644
View file @
e93735a2
# Do imports like python3 so our package works for 2 and 3
from
__future__
import
absolute_import
from
lxml
import
html
from
openpyxl
import
Workbook
from
openpyxl.utils
import
get_column_letter
from
premailer
import
Premailer
from
tablepyxl.style
import
Table
def string_to_int(s):
    """
    Convert a string of decimal digits to an int, returning 0 for anything else.

    Surrounding whitespace is ignored. Signs, decimals and other non-digit
    content yield 0.
    """
    text = s.strip()
    if not text.isdigit():
        return 0
    try:
        return int(text)
    except ValueError:
        # str.isdigit() also accepts characters such as superscript digits
        # that int() cannot parse.
        return 0
def get_Tables(doc):
    """
    Parse an html document string and return a Table for every <table>
    element in it. Comment nodes are dropped from the tree first.
    """
    root = html.fromstring(doc)
    for comment_node in root.xpath('//comment()'):
        comment_node.drop_tag()
    tables = []
    for table_element in root.xpath('//table'):
        tables.append(Table(table_element))
    return tables
def write_rows(worksheet, elem, row, column=1):
    """
    Writes every tr child element of elem to a row in the worksheet
    returns the next row after all rows are written

    :param worksheet: openpyxl worksheet to write into
    :param elem: object with a ``rows`` list (table head or body)
    :param row: 1-based worksheet row to start at
    :param column: 1-based worksheet column each row starts at
    :return: the row index following the last written row
    """
    from openpyxl.cell.cell import MergedCell

    initial_column = column
    for table_row in elem.rows:
        for table_cell in table_row.cells:
            cell = worksheet.cell(row=row, column=column)
            # Skip positions already covered by an earlier row/column span.
            while isinstance(cell, MergedCell):
                column += 1
                cell = worksheet.cell(row=row, column=column)

            # A span below 1 (attribute is "0", empty or not a number) is
            # treated as 1: otherwise the column would not advance and the
            # next cell would overwrite this one, or the merge range would
            # end before it starts.
            colspan = max(1, string_to_int(table_cell.element.get("colspan", "1")))
            rowspan = max(1, string_to_int(table_cell.element.get("rowspan", "1")))
            if rowspan > 1 or colspan > 1:
                worksheet.merge_cells(start_row=row, start_column=column,
                                      end_row=row + rowspan - 1, end_column=column + colspan - 1)

            cell.value = table_cell.value
            table_cell.format(cell)
            min_width = table_cell.get_dimension('min-width')
            max_width = table_cell.get_dimension('max-width')

            if colspan == 1:
                # Initially, when iterating for the first time through the loop, the width of all the cells is None.
                # As we start filling in contents, the initial width of the cell (which can be retrieved by:
                # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
                # cell in the same column (i.e. width of A2 = width of A1)
                column_letter = get_column_letter(column)
                width = max(worksheet.column_dimensions[column_letter].width or 0, len(table_cell.value) + 2)
                if max_width and width > max_width:
                    width = max_width
                elif min_width and width < min_width:
                    width = min_width
                worksheet.column_dimensions[column_letter].width = width
            column += colspan
        row += 1
        column = initial_column
    return row
def table_to_sheet(table, wb):
    """
    Create a new sheet in the workbook, titled with the table element's
    ``name`` attribute, and write the table into it starting at the first
    row and column.
    """
    sheet_title = table.element.get('name')
    sheet = wb.create_sheet(title=sheet_title)
    insert_table(table, sheet, 1, 1)
def document_to_workbook(doc, wb=None, base_url=None):
    """
    Convert an html document string into a workbook with one sheet per table.

    When no workbook is passed, a new one is created and its initial active
    sheet is removed. Styles are inlined with Premailer (classes are kept)
    before the tables are extracted. The workbook is returned.
    """
    if not wb:
        wb = Workbook()
        default_sheet = wb.active
        wb.remove(default_sheet)

    inliner = Premailer(doc, base_url=base_url, remove_classes=False)
    inline_styles_doc = inliner.transform()

    for html_table in get_Tables(inline_styles_doc):
        table_to_sheet(html_table, wb)

    return wb
def document_to_xl(doc, filename, base_url=None):
    """
    Convert an html document string into a workbook (one sheet per table)
    and save it to the file called filename.
    """
    document_to_workbook(doc, base_url=base_url).save(filename)
def insert_table(table, worksheet, column, row):
    """
    Write the table's head rows, then its body rows, into the worksheet,
    starting at the given column and row. A missing section is skipped.
    """
    for section_name in ('head', 'body'):
        section = getattr(table, section_name)
        if section:
            # write_rows returns the next free row, where the next section continues.
            row = write_rows(worksheet, section, row, column)
def insert_table_at_cell(table, cell):
    """
    Write the table into the cell's worksheet, starting at the column and
    row of the given openpyxl Cell object.
    """
    insert_table(table, cell.parent, cell.column, cell.row)
\ No newline at end of file
test/utility.py
0 → 100644
View file @
e93735a2
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
PIL
import
Image
import
numpy
as
np
from
tools.infer.utility
import
draw_ocr_box_txt
,
init_args
as
infer_args
def init_args():
    """
    Return the inference argument parser extended with the table options
    (output directory and table-structure model settings).
    """
    parser = infer_args()

    extra_arguments = (
        # params for output
        ("--output", {"type": str, "default": './output/table'}),
        # params for table structure
        ("--structure_max_len", {"type": int, "default": 488}),
        ("--structure_model_dir", {"type": str}),
        ("--structure_char_type", {"type": str, "default": 'en'}),
        ("--structure_char_dict_path",
         {"type": str, "default": "../ppocr/utils/dict/table_structure_dict.txt"}),
    )
    for flag, options in extra_arguments:
        parser.add_argument(flag, **options)

    return parser
def parse_args():
    """
    Build the parser with init_args and return its parsed arguments.
    """
    return init_args().parse_args()
def draw_result(image, result, font_path):
    """
    Draw the boxes and texts of all regions that are neither 'Table' nor
    'Figure' onto the image and return the rendered result.

    A numpy array image is converted to a PIL image first. For each drawn
    region, ``region['res'][0]`` and ``region['res'][1]`` are iterated in
    lockstep; the second is presumably a list of (text, score) pairs.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    boxes = []
    txts = []
    scores = []
    for region in result:
        # Table and Figure regions are skipped.
        if region['type'] in ('Table', 'Figure'):
            continue
        region_boxes, region_recs = region['res'][0], region['res'][1]
        for box, rec_res in zip(region_boxes, region_recs):
            boxes.append(np.array(box).reshape(-1, 2))
            txts.append(rec_res[0])
            scores.append(rec_res[1])

    return draw_ocr_box_txt(
        image, boxes, txts, scores, font_path=font_path, drop_score=0)
\ No newline at end of file
tools/infer/predict_det.py
View file @
e93735a2
...
@@ -43,7 +43,7 @@ class TextDetector(object):
...
@@ -43,7 +43,7 @@ class TextDetector(object):
pre_process_list
=
[{
pre_process_list
=
[{
'DetResizeForTest'
:
{
'DetResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
,
}
}
},
{
},
{
'NormalizeImage'
:
{
'NormalizeImage'
:
{
...
...
tools/infer/predict_system.py
View file @
e93735a2
...
@@ -24,6 +24,7 @@ import cv2
...
@@ -24,6 +24,7 @@ import cv2
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
logging
from
PIL
import
Image
from
PIL
import
Image
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_rec
as
predict_rec
...
@@ -38,6 +39,9 @@ logger = get_logger()
...
@@ -38,6 +39,9 @@ logger = get_logger()
class
TextSystem
(
object
):
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
if
not
args
.
show_log
:
logger
.
setLevel
(
logging
.
INFO
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
use_angle_cls
=
args
.
use_angle_cls
...
@@ -88,7 +92,7 @@ class TextSystem(object):
...
@@ -88,7 +92,7 @@ class TextSystem(object):
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
...
@@ -104,11 +108,11 @@ class TextSystem(object):
...
@@ -104,11 +108,11 @@ class TextSystem(object):
if
self
.
use_angle_cls
and
cls
:
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
img_crop_list
)
logger
.
info
(
"cls num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"cls num : {}, elapse : {}"
.
format
(
len
(
img_crop_list
),
elapse
))
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes
,
filter_rec_res
=
[],
[]
filter_boxes
,
filter_rec_res
=
[],
[]
...
...
tools/infer/utility.py
View file @
e93735a2
...
@@ -109,11 +109,12 @@ def init_args():
...
@@ -109,11 +109,12 @@ def init_args():
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
return
parser
return
parser
...
@@ -199,6 +200,8 @@ def create_predictor(args, mode, logger):
...
@@ -199,6 +200,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
model_dir
=
args
.
rec_model_dir
elif
mode
==
'structure'
:
model_dir
=
args
.
structure_model_dir
else
:
else
:
model_dir
=
args
.
e2e_model_dir
model_dir
=
args
.
e2e_model_dir
...
@@ -328,7 +331,9 @@ def create_predictor(args, mode, logger):
...
@@ -328,7 +331,9 @@ def create_predictor(args, mode, logger):
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
if
mode
==
'structure'
:
config
.
switch_ir_optim
(
False
)
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment