Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f6bf56d7
"docs/vscode:/vscode.git/clone" did not exist on "af9cab3b6c02a8e8fbbb2551199f77bccfc0e8f3"
Commit
f6bf56d7
authored
Aug 03, 2020
by
Kaushik Shivakumar
Browse files
clean and update exporter
parent
a3ae1258
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
21 deletions
+44
-21
research/object_detection/exporter_lib_tf2_test.py
research/object_detection/exporter_lib_tf2_test.py
+3
-3
research/object_detection/exporter_lib_v2.py
research/object_detection/exporter_lib_v2.py
+41
-18
No files found.
research/object_detection/exporter_lib_tf2_test.py
View file @
f6bf56d7
...
...
@@ -194,7 +194,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
saved_model_path
=
os
.
path
.
join
(
output_directory
,
'saved_model'
)
detect_fn
=
tf
.
saved_model
.
load
(
saved_model_path
)
image
=
self
.
get_dummy_input
(
input_type
)
detections
=
detect_fn
.
signatures
[
'serving_default'
]
(
tf
.
constant
(
image
))
detections
=
detect_fn
(
tf
.
constant
(
image
))
detection_fields
=
fields
.
DetectionResultFields
self
.
assertAllClose
(
detections
[
detection_fields
.
detection_boxes
],
...
...
@@ -232,8 +232,8 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
detect_fn_sig
=
detect_fn
.
signatures
[
'serving_default'
]
image
=
tf
.
constant
(
self
.
get_dummy_input
(
input_type
))
side_input
=
np
.
ones
((
1
,),
dtype
=
np
.
float32
)
#
detections
_one = tf.saved_model.load(saved_model_path)(image, side_input)
detections
=
detect_fn_sig
(
input_tensor
=
image
,
side_inp
=
tf
.
constant
(
side_input
))
detections
=
detect_fn_sig
(
input_tensor
=
image
,
side_inp
=
tf
.
constant
(
side_input
))
detection_fields
=
fields
.
DetectionResultFields
self
.
assertAllClose
(
detections
[
detection_fields
.
detection_boxes
],
...
...
research/object_detection/exporter_lib_v2.py
View file @
f6bf56d7
...
...
@@ -36,15 +36,33 @@ def _decode_tf_example(tf_example_string_tensor):
image_tensor
=
tensor_dict
[
fields
.
InputDataFields
.
image
]
return
image_tensor
def
_zip_side_inputs
(
side_input_shapes
=
""
,
side_input_types
=
""
,
side_input_names
=
""
):
"""Zips the side inputs together.
Args:
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Returns:
a zipped list of side input tuples.
"""
side_input_shapes
=
list
(
map
(
lambda
x
:
eval
(
'['
+
x
+
']'
),
side_input_shapes
.
split
(
"/"
)))
side_input_types
=
list
(
map
(
eval
,
side_input_types
.
split
(
","
)))
return
zip
(
side_input_shapes
,
side_input_types
,
side_input_names
.
split
(
","
))
class
DetectionInferenceModule
(
tf
.
Module
):
"""Detection Inference Module."""
def
__init__
(
self
,
detection_model
,
use_side_inputs
=
False
,
side_input_shapes
=
None
,
side_input_types
=
None
,
side_input_names
=
None
):
zipped_side_inputs
=
None
:
"""Initializes a module for detection.
Args:
...
...
@@ -75,27 +93,26 @@ class DetectionInferenceModule(tf.Module):
return
detections
class
DetectionFromImageModule
(
DetectionInferenceModule
):
"""Detection Inference Module for image inputs."""
def
__init__
(
self
,
detection_model
,
use_side_inputs
=
False
,
side_input_shapes
=
""
,
side_input_types
=
""
,
side_input_names
=
""
):
zipped_side_inputs
=
None
):
"""Initializes a module for detection.
Args:
detection_model: The detection model to use for inference.
"""
self
.
side_input_names
=
side_input_names
self
.
side_input_names
=
[]
sig
=
[
tf
.
TensorSpec
(
shape
=
[
1
,
None
,
None
,
3
],
dtype
=
tf
.
uint8
)]
if
use_side_inputs
:
for
info
in
zip
(
side_input_shapes
.
split
(
"/"
),
side_input_types
.
split
(
","
),
side_input_names
.
split
(
","
)):
sig
.
append
(
tf
.
TensorSpec
(
shape
=
eval
(
"["
+
info
[
0
]
+
"]"
),
dtype
=
eval
(
info
[
1
]),
for
info
in
zipped_side_inputs
:
self
.
side_input_names
.
append
(
info
[
2
])
sig
.
append
(
tf
.
TensorSpec
(
shape
=
info
[
0
],
dtype
=
info
[
1
],
name
=
info
[
2
]))
def
__call__
(
input_tensor
,
*
side_inputs
):
...
...
@@ -105,9 +122,8 @@ class DetectionFromImageModule(DetectionInferenceModule):
self
.
__call__
=
tf
.
function
(
__call__
,
input_signature
=
sig
)
super
(
DetectionFromImageModule
,
self
).
__init__
(
detection_model
,
side_input_shapes
,
side_input_types
,
side_input_names
)
use_side_inputs
,
zipped_side_inputs
)
class
DetectionFromFloatImageModule
(
DetectionInferenceModule
):
...
...
@@ -179,6 +195,12 @@ def export_inference_graph(input_type,
pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
trained_checkpoint_dir: Path to the trained checkpoint file.
output_directory: Path to write outputs.
use_side_inputs: boolean that determines whether side inputs should be
included in the input signature.
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Raises:
ValueError: if input_type is invalid.
"""
...
...
@@ -198,11 +220,12 @@ def export_inference_graph(input_type,
raise
ValueError
(
'Unrecognized `input_type`'
)
if
use_side_inputs
and
input_type
!=
'image_tensor'
:
raise
ValueError
(
'Side inputs supported for image_tensor input type only.'
)
detection_module
=
DETECTION_MODULE_MAP
[
input_type
](
detection_model
,
use_side_inputs
,
side_input_shapes
,
zipped_side_inputs
=
_zip_side_inputs
(
side_input_shapes
,
side_input_types
,
side_input_names
)
detection_module
=
DETECTION_MODULE_MAP
[
input_type
](
detection_model
,
use_side_inputs
,
zipped_side_inputs
)
# Getting the concrete function traces the graph and forces variables to
# be constructed --- only after this can we save the checkpoint and
# saved model.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment