ModelZoo / ResNet50_tensorflow / Commits

Commit eb795bf7, authored Aug 04, 2020 by Kaushik Shivakumar
progress on PR
Parent: d08a1c66
Showing 5 changed files with 83 additions and 144 deletions
research/object_detection/core/box_list_ops.py         +43   -1
research/object_detection/core/box_list_ops_test.py    +25   -1
research/object_detection/exporter_lib_tf2_test.py      +3  -45
research/object_detection/exporter_lib_v2.py           +11  -74
research/object_detection/exporter_main_v2.py           +1  -23
research/object_detection/core/box_list_ops.py

@@ -303,6 +303,49 @@ def iou(boxlist1, boxlist2, scope=None):
        tf.equal(intersections, 0.0),
        tf.zeros_like(intersections),
        tf.truediv(intersections, unions))
+
+
+def l1(boxlist1, boxlist2, scope=None):
+  """Computes l1 loss (pairwise) between two boxlists.
+
+  Args:
+    boxlist1: BoxList holding N boxes
+    boxlist2: BoxList holding M boxes
+    scope: name scope.
+
+  Returns:
+    a tensor with shape [N, M] representing the pairwise L1 loss.
+  """
+  with tf.name_scope(scope, 'PairwiseL1'):
+    ycenter1, xcenter1, h1, w1 = boxlist1.get_center_coordinates_and_sizes()
+    ycenter2, xcenter2, h2, w2 = boxlist2.get_center_coordinates_and_sizes()
+    ycenters = tf.abs(tf.expand_dims(ycenter2, axis=0) -
+                      tf.expand_dims(tf.transpose(ycenter1), axis=1))
+    xcenters = tf.abs(tf.expand_dims(xcenter2, axis=0) -
+                      tf.expand_dims(tf.transpose(xcenter1), axis=1))
+    heights = tf.abs(tf.expand_dims(h2, axis=0) -
+                     tf.expand_dims(tf.transpose(h1), axis=1))
+    widths = tf.abs(tf.expand_dims(w2, axis=0) -
+                    tf.expand_dims(tf.transpose(w1), axis=1))
+    return ycenters + xcenters + heights + widths
+
+
+def giou_loss(boxlist1, boxlist2, scope=None):
+  """Computes generalized IOU loss between two boxlists pairwise,
+  as described at giou.stanford.edu.
+
+  Args:
+    boxlist1: BoxList holding N boxes
+    boxlist2: BoxList holding M boxes
+    scope: name scope.
+
+  Returns:
+    a tensor with shape [N, M] representing the pairwise GIoU loss.
+  """
+  with tf.name_scope(scope, "PairwiseGIoU"):
+    N = boxlist1.num_boxes()
+    M = boxlist2.num_boxes()
+    boxes1 = tf.repeat(boxlist1.get(), repeats=M, axis=0)
+    boxes2 = tf.tile(boxlist2.get(), multiples=[N, 1])
+    return tf.reshape(1.0 - ops.giou(boxes1, boxes2), [N, M])
+
+
def matched_iou(boxlist1, boxlist2, scope=None):
  """Compute intersection-over-union between corresponding boxes in boxlists.
...
@@ -324,7 +367,6 @@ def matched_iou(boxlist1, boxlist2, scope=None):
        tf.equal(intersections, 0.0),
        tf.zeros_like(intersections),
        tf.truediv(intersections, unions))


def ioa(boxlist1, boxlist2, scope=None):
  """Computes pairwise intersection-over-area between box collections.
...
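The new giou_loss pairs every box in boxlist1 with every box in boxlist2 by repeating one set and tiling the other, then reshapes the flat result back to [N, M]. As a sanity check on the convention (loss = 1 - GIoU, so identical boxes give 0 and disjoint boxes give values above 1), the following throwaway snippet recomputes GIoU directly from corner coordinates for the two box pairs used in the test below. It is an illustration of the definition at giou.stanford.edu, not the library's ops.giou implementation.

def giou_loss_single(box1, box2):
  """1 - GIoU for two [ymin, xmin, ymax, xmax] boxes (hand-check only)."""
  area = lambda b: max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
  inter = [max(box1[0], box2[0]), max(box1[1], box2[1]),
           min(box1[2], box2[2]), min(box1[3], box2[3])]
  intersection = area(inter) if inter[2] > inter[0] and inter[3] > inter[1] else 0.0
  union = area(box1) + area(box2) - intersection
  # Smallest axis-aligned box enclosing both inputs.
  hull = [min(box1[0], box2[0]), min(box1[1], box2[1]),
          max(box1[2], box2[2]), max(box1[3], box2[3])]
  giou = intersection / union - (area(hull) - union) / area(hull)
  return 1.0 - giou

print(giou_loss_single([5., 7., 7., 9.], [5., 7., 7., 9.]))    # 0.0
print(giou_loss_single([5., 7., 7., 9.], [5., 11., 7., 13.]))  # 1.333... == 4/3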
research/object_detection/core/box_list_ops_test.py

@@ -217,7 +217,6 @@ class BoxListOpsTest(test_case.TestCase):
  def test_iou(self):
    def graph_fn():
      corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
      corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
                              [0.0, 0.0, 20.0, 20.0]])
...
@@ -229,6 +228,31 @@ class BoxListOpsTest(test_case.TestCase):
    iou_output = self.execute(graph_fn, [])
    self.assertAllClose(iou_output, exp_output)

+  def test_l1(self):
+    def graph_fn():
+      corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
+      corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
+                              [0.0, 0.0, 20.0, 20.0]])
+      boxes1 = box_list.BoxList(corners1)
+      boxes2 = box_list.BoxList(corners2)
+      l1 = box_list_ops.l1(boxes1, boxes2)
+      return l1
+    exp_output = [[5.0, 22.5, 45.5], [8.5, 19.0, 40.0]]
+    l1_output = self.execute(graph_fn, [])
+    self.assertAllClose(l1_output, exp_output)
+
+  def test_giou(self):
+    def graph_fn():
+      corners1 = tf.constant([[5.0, 7.0, 7.0, 9.0]])
+      corners2 = tf.constant([[5.0, 7.0, 7.0, 9.0], [5.0, 11.0, 7.0, 13.0]])
+      boxes1 = box_list.BoxList(corners1)
+      boxes2 = box_list.BoxList(corners2)
+      giou = box_list_ops.giou_loss(boxes1, boxes2)
+      return giou
+    exp_output = [[0.0, 4.0 / 3.0]]
+    giou_output = self.execute(graph_fn, [])
+    self.assertAllClose(giou_output, exp_output)
+
  def test_matched_iou(self):
    def graph_fn():
      corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
...
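The exp_output values in test_l1 follow directly from the center/size parameterization used by get_center_coordinates_and_sizes: each entry is |Δy_center| + |Δx_center| + |Δheight| + |Δwidth|. A quick hand-check of the first entry (box [4, 3, 7, 5] against box [3, 4, 6, 8]), written as a throwaway snippet rather than the library call:

def center_size(box):
  """[ymin, xmin, ymax, xmax] -> (ycenter, xcenter, height, width)."""
  ymin, xmin, ymax, xmax = box
  return ((ymin + ymax) / 2.0, (xmin + xmax) / 2.0, ymax - ymin, xmax - xmin)

a = center_size([4.0, 3.0, 7.0, 5.0])   # (5.5, 4.0, 3.0, 2.0)
b = center_size([3.0, 4.0, 6.0, 8.0])   # (4.5, 6.0, 3.0, 4.0)
print(sum(abs(x - y) for x, y in zip(a, b)))  # 5.0, matching exp_output[0][0]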
research/object_detection/exporter_lib_tf2_test.py

@@ -54,11 +54,8 @@ class FakeModel(model.DetectionModel):
    true_image_shapes = []  # Doesn't matter for the fake model.
    return tf.identity(inputs), true_image_shapes

-  def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
-    return_dict = {'image': self._conv(preprocessed_inputs)}
-    if 'side_inp' in side_inputs:
-      return_dict['image'] += side_inputs['side_inp']
-    return return_dict
+  def predict(self, preprocessed_inputs, true_image_shapes):
+    return {'image': self._conv(preprocessed_inputs)}

  def postprocess(self, prediction_dict, true_image_shapes):
    predict_tensor_sum = tf.reduce_sum(prediction_dict['image'])
...
@@ -192,7 +189,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
      saved_model_path = os.path.join(output_directory, 'saved_model')
      detect_fn = tf.saved_model.load(saved_model_path)
      image = self.get_dummy_input(input_type)
-      detections = detect_fn(tf.constant(image))
+      detections = detect_fn(image)
      detection_fields = fields.DetectionResultFields
      self.assertAllClose(detections[detection_fields.detection_boxes],
...
@@ -206,45 +203,6 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
                          [[1, 2], [2, 1]])
      self.assertAllClose(detections[detection_fields.num_detections], [2, 1])

-  def test_export_saved_model_and_run_inference_with_side_inputs(
-      self, input_type='image_tensor'):
-    tmp_dir = self.get_temp_dir()
-    self._save_checkpoint_from_mock_model(tmp_dir)
-    with mock.patch.object(
-        model_builder, 'build', autospec=True) as mock_builder:
-      mock_builder.return_value = FakeModel()
-      output_directory = os.path.join(tmp_dir, 'output')
-      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
-      exporter_lib_v2.export_inference_graph(
-          input_type=input_type,
-          pipeline_config=pipeline_config,
-          trained_checkpoint_dir=tmp_dir,
-          output_directory=output_directory,
-          use_side_inputs=True,
-          side_input_shapes="1",
-          side_input_names="side_inp",
-          side_input_types="tf.float32")
-      saved_model_path = os.path.join(output_directory, 'saved_model')
-      detect_fn = tf.saved_model.load(saved_model_path)
-      detect_fn_sig = detect_fn.signatures['serving_default']
-      image = tf.constant(self.get_dummy_input(input_type))
-      side_input = np.ones((1,), dtype=np.float32)
-      detections = detect_fn_sig(input_tensor=image,
-                                 side_inp=tf.constant(side_input))
-      detection_fields = fields.DetectionResultFields
-      self.assertAllClose(detections[detection_fields.detection_boxes],
-                          [[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.8, 0.8]],
-                           [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
-      self.assertAllClose(detections[detection_fields.detection_scores],
-                          [[400.7, 400.6], [400.9, 400.0]])
-      self.assertAllClose(detections[detection_fields.detection_classes],
-                          [[1, 2], [2, 1]])
-      self.assertAllClose(detections[detection_fields.num_detections], [2, 1])
-
  def test_export_checkpoint_and_run_inference_with_image(self):
    tmp_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(tmp_dir, conv_weight_scalar=2.0)
...
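After this change the remaining test exercises the exported model the way an end user would: restore the SavedModel and call it on a batched uint8 image, with no side-input arguments. A minimal sketch of that flow outside the test harness (the path and image below are placeholders, not values from this commit):

import numpy as np
import tensorflow as tf

saved_model_path = 'exported_model/saved_model'  # placeholder path
detect_fn = tf.saved_model.load(saved_model_path)

# Shape [1, height, width, 3], uint8, matching the exported input signature.
image = np.zeros((1, 320, 320, 3), dtype=np.uint8)
detections = detect_fn(image)

print(detections['num_detections'])
print(detections['detection_boxes'].shape)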
research/object_detection/exporter_lib_v2.py

@@ -21,7 +21,7 @@ from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
-import ast


def _decode_image(encoded_image_string_tensor):
  image_tensor = tf.image.decode_image(encoded_image_string_tensor,
...
@@ -36,32 +36,11 @@ def _decode_tf_example(tf_example_string_tensor):
  image_tensor = tensor_dict[fields.InputDataFields.image]
  return image_tensor


-def _zip_side_inputs(side_input_shapes="",
-                     side_input_types="",
-                     side_input_names=""):
-  """Zips the side inputs together.
-
-  Args:
-    side_input_shapes: forward-slash-separated list of comma-separated lists
-      describing input shapes.
-    side_input_types: comma-separated list of the types of the inputs.
-    side_input_names: comma-separated list of the names of the inputs.
-
-  Returns:
-    a zipped list of side input tuples.
-  """
-  side_input_shapes = list(
-      map(lambda x: ast.literal_eval('[' + x + ']'),
-          side_input_shapes.split('/')))
-  side_input_types = eval('[' + side_input_types + ']')
-  side_input_names = side_input_names.split(',')
-  return zip(side_input_shapes, side_input_types, side_input_names)
-
-
class DetectionInferenceModule(tf.Module):
  """Detection Inference Module."""

-  def __init__(self, detection_model,
-               use_side_inputs=False,
-               zipped_side_inputs=None):
+  def __init__(self, detection_model):
    """Initializes a module for detection.

    Args:
...
@@ -69,7 +48,7 @@ class DetectionInferenceModule(tf.Module):
    """
    self._model = detection_model

-  def _run_inference_on_images(self, image, **kwargs):
+  def _run_inference_on_images(self, image):
    """Cast image to float and run inference.

    Args:
...
@@ -81,7 +60,7 @@ class DetectionInferenceModule(tf.Module):
    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
-    prediction_dict = self._model.predict(image, shapes, **kwargs)
+    prediction_dict = self._model.predict(image, shapes)
    detections = self._model.postprocess(prediction_dict, shapes)
    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = (
...
@@ -96,33 +75,11 @@ class DetectionInferenceModule(tf.Module):
class DetectionFromImageModule(DetectionInferenceModule):
  """Detection Inference Module for image inputs."""

-  def __init__(self, detection_model,
-               use_side_inputs=False,
-               zipped_side_inputs=None):
-    """Initializes a module for detection.
-
-    Args:
-      detection_model: The detection model to use for inference.
-    """
-    self.side_input_names = []
-    sig = [tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)]
-    if use_side_inputs:
-      for info in zipped_side_inputs:
-        self.side_input_names.append(info[2])
-        sig.append(tf.TensorSpec(shape=info[0], dtype=info[1], name=info[2]))
-
-    def __call__(input_tensor, *side_inputs):
-      kwargs = dict(zip(self.side_input_names, side_inputs))
-      return self._run_inference_on_images(input_tensor, **kwargs)
-
-    self.__call__ = tf.function(__call__, input_signature=sig)
-    super(DetectionFromImageModule, self).__init__(detection_model,
-                                                   use_side_inputs,
-                                                   zipped_side_inputs)
+  @tf.function(
+      input_signature=[
+          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
+  def __call__(self, input_tensor):
+    return self._run_inference_on_images(input_tensor)


class DetectionFromFloatImageModule(DetectionInferenceModule):
...
@@ -176,11 +133,7 @@ DETECTION_MODULE_MAP = {
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
-                           output_directory,
-                           use_side_inputs=False,
-                           side_input_shapes="",
-                           side_input_types="",
-                           side_input_names=""):
+                           output_directory):
  """Exports inference graph for the model specified in the pipeline config.

  This function creates `output_directory` if it does not already exist,
...
@@ -194,12 +147,6 @@ def export_inference_graph(input_type,
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_dir: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
-    use_side_inputs: boolean that determines whether side inputs should be
-      included in the input signature.
-    side_input_shapes: forward-slash-separated list of comma-separated lists
-      describing input shapes.
-    side_input_types: comma-separated list of the types of the inputs.
-    side_input_names: comma-separated list of the names of the inputs.

  Raises:
    ValueError: if input_type is invalid.
  """
...
@@ -217,17 +164,7 @@ def export_inference_graph(input_type,
  if input_type not in DETECTION_MODULE_MAP:
    raise ValueError('Unrecognized `input_type`')
-  if use_side_inputs and input_type != 'image_tensor':
-    raise ValueError('Side inputs supported for image_tensor input type only.')
-  zipped_side_inputs = None
-  if use_side_inputs:
-    zipped_side_inputs = _zip_side_inputs(side_input_shapes,
-                                          side_input_types,
-                                          side_input_names)
-  detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
-                                                      use_side_inputs,
-                                                      zipped_side_inputs)
+  detection_module = DETECTION_MODULE_MAP[input_type](detection_model)
  # Getting the concrete function traces the graph and forces variables to
  # be constructed --- only after this can we save the checkpoint and
  # saved model.
...
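For reference, the removed _zip_side_inputs helper turned three flat flag strings into (shape, dtype, name) tuples: shapes as a slash-separated list of comma-separated integer lists, types as a comma-separated expression evaluated with eval, and names as a plain comma-separated list. A standalone sketch mirroring that parsing (the function name and example values below are illustrative only, not part of the commit):

import ast
import tensorflow as tf

def zip_side_inputs(side_input_shapes='', side_input_types='', side_input_names=''):
  """Mirror of the removed helper: e.g. '1,-1/2', 'tf.float32,tf.uint8', 'a,b'."""
  shapes = [ast.literal_eval('[' + x + ']') for x in side_input_shapes.split('/')]
  types = eval('[' + side_input_types + ']')  # the deleted code also used eval here
  names = side_input_names.split(',')
  return list(zip(shapes, types, names))

# [([1, -1], tf.float32, 'side_inp'), ([2], tf.uint8, 'mask')]
print(zip_side_inputs('1,-1/2', 'tf.float32,tf.uint8', 'side_inp,mask'))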
research/object_detection/exporter_main_v2.py

@@ -106,27 +106,6 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
                    'pipeline_pb2.TrainEvalPipelineConfig '
                    'text proto to override pipeline_config_path.')
-flags.DEFINE_boolean('use_side_inputs', False,
-                     'If True, uses side inputs as well as image inputs.')
-flags.DEFINE_string('side_input_shapes', "",
-                    'If use_side_inputs is True, this explicitly sets '
-                    'the shape of the side input tensors to a fixed size. The '
-                    'dimensions are to be provided as a comma-separated list '
-                    'of integers. A value of -1 can be used for unknown '
-                    'dimensions. A `/` denotes a break, starting the shape of '
-                    'the next side input tensor. This flag is required if '
-                    'using side inputs.')
-flags.DEFINE_string('side_input_types', "",
-                    'If use_side_inputs is True, this explicitly sets '
-                    'the type of the side input tensors. The '
-                    'dimensions are to be provided as a comma-separated list '
-                    'of types, each of `string`, `integer`, or `float`. '
-                    'This flag is required if using side inputs.')
-flags.DEFINE_string('side_input_names', "",
-                    'If use_side_inputs is True, this explicitly sets '
-                    'the names of the side input tensors required by the model '
-                    'assuming the names will be a comma-separated list of '
-                    'strings. This flag is required if using side inputs.')
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_dir')
...
@@ -140,8 +119,7 @@ def main(_):
  text_format.Merge(FLAGS.config_override, pipeline_config)
  exporter_lib_v2.export_inference_graph(
      FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_dir,
-      FLAGS.output_directory, FLAGS.use_side_inputs, FLAGS.side_input_shapes,
-      FLAGS.side_input_types, FLAGS.side_input_names)
+      FLAGS.output_directory)


if __name__ == '__main__':
...
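With the side-input flags gone, exporter_main_v2.py passes only the four core arguments through to export_inference_graph. A sketch of the equivalent direct call in Python (paths are placeholders, and the import layout assumes the standard object_detection package rather than anything added in this commit):

from google.protobuf import text_format
from object_detection import exporter_lib_v2
from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with open('configs/pipeline.config') as f:  # placeholder path
  text_format.Merge(f.read(), pipeline_config)

exporter_lib_v2.export_inference_graph(
    input_type='image_tensor',
    pipeline_config=pipeline_config,
    trained_checkpoint_dir='training/checkpoints',  # placeholder path
    output_directory='exported_model')              # placeholder path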