Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
024ebd81
Commit
024ebd81
authored
Dec 03, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Dec 03, 2021
Browse files
Internal change.
PiperOrigin-RevId: 414045177
parent
002b4ec4
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
169 additions
and
80 deletions
+169
-80
official/vision/beta/serving/detection.py
official/vision/beta/serving/detection.py
+29
-10
official/vision/beta/serving/detection_test.py
official/vision/beta/serving/detection_test.py
+20
-5
official/vision/beta/serving/export_base.py
official/vision/beta/serving/export_base.py
+3
-0
official/vision/beta/serving/export_saved_model_lib.py
official/vision/beta/serving/export_saved_model_lib.py
+4
-0
official/vision/beta/serving/export_tflite_lib_test.py
official/vision/beta/serving/export_tflite_lib_test.py
+47
-21
official/vision/beta/serving/image_classification.py
official/vision/beta/serving/image_classification.py
+14
-12
official/vision/beta/serving/image_classification_test.py
official/vision/beta/serving/image_classification_test.py
+19
-10
official/vision/beta/serving/semantic_segmentation.py
official/vision/beta/serving/semantic_segmentation.py
+14
-12
official/vision/beta/serving/semantic_segmentation_test.py
official/vision/beta/serving/semantic_segmentation_test.py
+19
-10
No files found.
official/vision/beta/serving/detection.py
View file @
024ebd81
...
...
@@ -52,6 +52,18 @@ class DetectionModule(export_base.ExportModule):
return
model
def
_build_anchor_boxes
(
self
):
"""Builds and returns anchor boxes."""
model_params
=
self
.
params
.
task
.
model
input_anchor
=
anchor
.
build_anchor_generator
(
min_level
=
model_params
.
min_level
,
max_level
=
model_params
.
max_level
,
num_scales
=
model_params
.
anchor
.
num_scales
,
aspect_ratios
=
model_params
.
anchor
.
aspect_ratios
,
anchor_size
=
model_params
.
anchor
.
anchor_size
)
return
input_anchor
(
image_size
=
(
self
.
_input_image_size
[
0
],
self
.
_input_image_size
[
1
]))
def
_build_inputs
(
self
,
image
):
"""Builds detection model inputs for serving."""
model_params
=
self
.
params
.
task
.
model
...
...
@@ -67,15 +79,7 @@ class DetectionModule(export_base.ExportModule):
self
.
_input_image_size
,
2
**
model_params
.
max_level
),
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
)
input_anchor
=
anchor
.
build_anchor_generator
(
min_level
=
model_params
.
min_level
,
max_level
=
model_params
.
max_level
,
num_scales
=
model_params
.
anchor
.
num_scales
,
aspect_ratios
=
model_params
.
anchor
.
aspect_ratios
,
anchor_size
=
model_params
.
anchor
.
anchor_size
)
anchor_boxes
=
input_anchor
(
image_size
=
(
self
.
_input_image_size
[
0
],
self
.
_input_image_size
[
1
]))
anchor_boxes
=
self
.
_build_anchor_boxes
()
return
image
,
anchor_boxes
,
image_info
...
...
@@ -133,7 +137,22 @@ class DetectionModule(export_base.ExportModule):
Tensor holding detection output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if
self
.
_input_type
!=
'tflite'
:
images
,
anchor_boxes
,
image_info
=
self
.
preprocess
(
images
)
else
:
with
tf
.
device
(
'cpu:0'
):
anchor_boxes
=
self
.
_build_anchor_boxes
()
# image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
# format of [[original_height, original_width],
# [desired_height, desired_width], [y_scale, x_scale],
# [y_offset, x_offset]]. When input_type is tflite, input image is
# supposed to be preprocessed already.
image_info
=
tf
.
convert_to_tensor
([[
self
.
_input_image_size
,
self
.
_input_image_size
,
[
1.0
,
1.0
],
[
0
,
0
]
]],
dtype
=
tf
.
float32
)
input_image_shape
=
image_info
[:,
1
,
:]
# To overcome keras.Model extra limitation to save a model with layers that
...
...
official/vision/beta/serving/detection_test.py
View file @
024ebd81
...
...
@@ -30,12 +30,15 @@ from official.vision.beta.serving import detection
class
DetectionExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_detection_module
(
self
,
experiment_name
):
def
_get_detection_module
(
self
,
experiment_name
,
input_type
):
params
=
exp_factory
.
get_exp_config
(
experiment_name
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
params
.
task
.
model
.
detection_generator
.
nms_version
=
'batched'
detection_module
=
detection
.
DetectionModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
640
,
640
])
params
,
batch_size
=
1
,
input_image_size
=
[
640
,
640
],
input_type
=
input_type
)
return
detection_module
def
_export_from_module
(
self
,
module
,
input_type
,
save_directory
):
...
...
@@ -65,24 +68,30 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
for
b
in
range
(
batch_size
)]
elif
input_type
==
'tflite'
:
return
tf
.
zeros
((
batch_size
,
h
,
w
,
3
),
dtype
=
np
.
float32
)
@
parameterized
.
parameters
(
(
'image_tensor'
,
'fasterrcnn_resnetfpn_coco'
,
[
384
,
384
]),
(
'image_bytes'
,
'fasterrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'fasterrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'tflite'
,
'fasterrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_tensor'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_bytes'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
384
]),
(
'tf_example'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'tflite'
,
'maskrcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_tensor'
,
'retinanet_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_bytes'
,
'retinanet_resnetfpn_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'retinanet_resnetfpn_coco'
,
[
384
,
640
]),
(
'tflite'
,
'retinanet_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_tensor'
,
'retinanet_resnetfpn_coco'
,
[
384
,
384
]),
(
'image_bytes'
,
'retinanet_spinenet_coco'
,
[
640
,
640
]),
(
'tf_example'
,
'retinanet_spinenet_coco'
,
[
640
,
384
]),
(
'tflite'
,
'retinanet_spinenet_coco'
,
[
640
,
640
]),
)
def
test_export
(
self
,
input_type
,
experiment_name
,
image_size
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_detection_module
(
experiment_name
)
module
=
self
.
_get_detection_module
(
experiment_name
,
input_type
)
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
...
...
@@ -100,6 +109,12 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
images
=
self
.
_get_dummy_input
(
input_type
,
batch_size
=
1
,
image_size
=
image_size
)
if
input_type
==
'tflite'
:
processed_images
=
tf
.
zeros
(
image_size
+
[
3
],
dtype
=
tf
.
float32
)
anchor_boxes
=
module
.
_build_anchor_boxes
()
image_info
=
tf
.
convert_to_tensor
(
[
image_size
,
image_size
,
[
1.0
,
1.0
],
[
0
,
0
]],
dtype
=
tf
.
float32
)
else
:
processed_images
,
anchor_boxes
,
image_info
=
module
.
_build_inputs
(
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
))
image_shape
=
image_info
[
1
,
:]
...
...
official/vision/beta/serving/export_base.py
View file @
024ebd81
...
...
@@ -31,6 +31,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
*
,
batch_size
:
int
,
input_image_size
:
List
[
int
],
input_type
:
str
=
'image_tensor'
,
num_channels
:
int
=
3
,
model
:
Optional
[
tf
.
keras
.
Model
]
=
None
):
"""Initializes a module for export.
...
...
@@ -40,6 +41,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
batch_size: The batch size of the model input. Can be `int` or None.
input_image_size: List or Tuple of size of the input image. For 2D image,
it is [height, width].
input_type: The input signature type.
num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported.
"""
...
...
@@ -47,6 +49,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
self
.
_batch_size
=
batch_size
self
.
_input_image_size
=
input_image_size
self
.
_num_channels
=
num_channels
self
.
_input_type
=
input_type
if
model
is
None
:
model
=
self
.
_build_model
()
# pylint: disable=assignment-from-none
super
().
__init__
(
params
=
params
,
model
=
model
)
...
...
official/vision/beta/serving/export_saved_model_lib.py
View file @
024ebd81
...
...
@@ -89,6 +89,7 @@ def export_inference_graph(
params
=
params
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
,
input_type
=
input_type
,
num_channels
=
num_channels
)
elif
isinstance
(
params
.
task
,
configs
.
retinanet
.
RetinaNetTask
)
or
isinstance
(
params
.
task
,
configs
.
maskrcnn
.
MaskRCNNTask
):
...
...
@@ -96,6 +97,7 @@ def export_inference_graph(
params
=
params
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
,
input_type
=
input_type
,
num_channels
=
num_channels
)
elif
isinstance
(
params
.
task
,
configs
.
semantic_segmentation
.
SemanticSegmentationTask
):
...
...
@@ -103,6 +105,7 @@ def export_inference_graph(
params
=
params
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
,
input_type
=
input_type
,
num_channels
=
num_channels
)
elif
isinstance
(
params
.
task
,
configs
.
video_classification
.
VideoClassificationTask
):
...
...
@@ -110,6 +113,7 @@ def export_inference_graph(
params
=
params
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
,
input_type
=
input_type
,
num_channels
=
num_channels
)
else
:
raise
ValueError
(
'Export module not implemented for {} task.'
.
format
(
...
...
official/vision/beta/serving/export_tflite_lib_test.py
View file @
024ebd81
...
...
@@ -30,18 +30,10 @@ from official.vision.beta.serving import semantic_segmentation as semantic_segme
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
_test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'test.tfrecord'
)
self
.
_create_test_tfrecord
(
num_samples
=
50
)
def
_create_test_tfrecord
(
self
,
num_samples
):
tfexample_utils
.
dump_to_tfrecord
(
self
.
_test_tfrecord_file
,
[
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
256
,
image_width
=
256
))
for
_
in
range
(
num_samples
)
])
def
_create_test_tfrecord
(
self
,
tfrecord_file
,
example
,
num_samples
):
examples
=
[
example
]
*
num_samples
tfexample_utils
.
dump_to_tfrecord
(
record_file
=
tfrecord_file
,
tf_examples
=
examples
)
def
_export_from_module
(
self
,
module
,
input_type
,
saved_model_dir
):
signatures
=
module
.
get_inference_signatures
(
...
...
@@ -51,16 +43,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'mobilenet_imagenet'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
],
input_image_size
=
[[
224
,
224
]]))
def
test_export_tflite_image_classification
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'cls_test.tfrecord'
)
example
=
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
]))
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
_
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
_
test_tfrecord_file
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
temp_dir
=
self
.
get_temp_dir
()
module
=
image_classification_serving
.
ClassificationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
...
...
@@ -78,13 +79,26 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
combinations
.
combine
(
experiment
=
[
'retinanet_mobile_coco'
],
quant_type
=
[
None
,
'default'
,
'fp16'
],
input_image_size
=
[[
256
,
256
]]))
input_image_size
=
[[
384
,
384
]]))
def
test_export_tflite_detection
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'det_test.tfrecord'
)
example
=
tfexample_utils
.
create_detection_test_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
],
image_channel
=
3
,
num_instances
=
10
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
temp_dir
=
self
.
get_temp_dir
()
module
=
detection_serving
.
DetectionModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
...
...
@@ -100,15 +114,27 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'
seg
_deeplabv3_pascal'
],
quant_type
=
[
None
,
'default'
,
'fp16'
],
experiment
=
[
'
mnv2
_deeplabv3_pascal'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
],
input_image_size
=
[[
512
,
512
]]))
def
test_export_tflite_semantic_segmentation
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'seg_test.tfrecord'
)
example
=
tfexample_utils
.
create_segmentation_test_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
],
image_channel
=
3
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
temp_dir
=
self
.
get_temp_dir
()
module
=
semantic_segmentation_serving
.
SegmentationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
...
...
official/vision/beta/serving/image_classification.py
View file @
024ebd81
...
...
@@ -63,18 +63,20 @@ class ClassificationModule(export_base.ExportModule):
Returns:
Tensor holding classification output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if
self
.
_input_type
!=
'tflite'
:
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
cast
(
images
,
dtype
=
tf
.
float32
)
images
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
self
.
_build_inputs
,
elems
=
images
,
self
.
_build_inputs
,
elems
=
images
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
float32
),
parallel_iterations
=
32
)
)
parallel_iterations
=
32
))
logits
=
self
.
inference_step
(
images
)
probs
=
tf
.
nn
.
softmax
(
logits
)
...
...
official/vision/beta/serving/image_classification_test.py
View file @
024ebd81
...
...
@@ -30,11 +30,14 @@ from official.vision.beta.serving import image_classification
class
ImageClassificationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_classification_module
(
self
):
def
_get_classification_module
(
self
,
input_type
):
params
=
exp_factory
.
get_exp_config
(
'resnet_imagenet'
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
classification_module
=
image_classification
.
ClassificationModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
224
,
224
])
params
,
batch_size
=
1
,
input_image_size
=
[
224
,
224
],
input_type
=
input_type
)
return
classification_module
def
_export_from_module
(
self
,
module
,
input_type
,
save_directory
):
...
...
@@ -65,15 +68,18 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
]
elif
input_type
==
'tflite'
:
return
tf
.
zeros
((
1
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
@
parameterized
.
parameters
(
{
'input_type'
:
'image_tensor'
},
{
'input_type'
:
'image_bytes'
},
{
'input_type'
:
'tf_example'
},
{
'input_type'
:
'tflite'
},
)
def
test_export
(
self
,
input_type
=
'image_tensor'
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_classification_module
()
module
=
self
.
_get_classification_module
(
input_type
)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module
.
model
.
test_trackable
=
tf
.
keras
.
layers
.
InputLayer
(
input_shape
=
(
4
,))
...
...
@@ -90,6 +96,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
classification_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
)
if
input_type
!=
'tflite'
:
processed_images
=
tf
.
nest
.
map_structure
(
tf
.
stop_gradient
,
tf
.
map_fn
(
...
...
@@ -97,6 +104,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems
=
tf
.
zeros
((
1
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
),
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
[
224
,
224
,
3
],
dtype
=
tf
.
float32
)))
else
:
processed_images
=
images
expected_logits
=
module
.
model
(
processed_images
,
training
=
False
)
expected_prob
=
tf
.
nn
.
softmax
(
expected_logits
)
out
=
classification_fn
(
tf
.
constant
(
images
))
...
...
official/vision/beta/serving/semantic_segmentation.py
View file @
024ebd81
...
...
@@ -62,18 +62,20 @@ class SegmentationModule(export_base.ExportModule):
Returns:
Tensor holding classification output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if
self
.
_input_type
!=
'tflite'
:
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
cast
(
images
,
dtype
=
tf
.
float32
)
images
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
self
.
_build_inputs
,
elems
=
images
,
self
.
_build_inputs
,
elems
=
images
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
float32
),
parallel_iterations
=
32
)
)
parallel_iterations
=
32
))
masks
=
self
.
inference_step
(
images
)
masks
=
tf
.
image
.
resize
(
masks
,
self
.
_input_image_size
,
method
=
'bilinear'
)
...
...
official/vision/beta/serving/semantic_segmentation_test.py
View file @
024ebd81
...
...
@@ -30,10 +30,13 @@ from official.vision.beta.serving import semantic_segmentation
class
SemanticSegmentationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_segmentation_module
(
self
):
def
_get_segmentation_module
(
self
,
input_type
):
params
=
exp_factory
.
get_exp_config
(
'mnv2_deeplabv3_pascal'
)
segmentation_module
=
semantic_segmentation
.
SegmentationModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
112
,
112
])
params
,
batch_size
=
1
,
input_image_size
=
[
112
,
112
],
input_type
=
input_type
)
return
segmentation_module
def
_export_from_module
(
self
,
module
,
input_type
,
save_directory
):
...
...
@@ -62,15 +65,18 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
]
elif
input_type
==
'tflite'
:
return
tf
.
zeros
((
1
,
112
,
112
,
3
),
dtype
=
np
.
float32
)
@
parameterized
.
parameters
(
{
'input_type'
:
'image_tensor'
},
{
'input_type'
:
'image_bytes'
},
{
'input_type'
:
'tf_example'
},
{
'input_type'
:
'tflite'
},
)
def
test_export
(
self
,
input_type
=
'image_tensor'
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_segmentation_module
()
module
=
self
.
_get_segmentation_module
(
input_type
)
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
...
...
@@ -86,6 +92,7 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
segmentation_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
)
if
input_type
!=
'tflite'
:
processed_images
=
tf
.
nest
.
map_structure
(
tf
.
stop_gradient
,
tf
.
map_fn
(
...
...
@@ -93,6 +100,8 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
elems
=
tf
.
zeros
((
1
,
112
,
112
,
3
),
dtype
=
tf
.
uint8
),
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
[
112
,
112
,
3
],
dtype
=
tf
.
float32
)))
else
:
processed_images
=
images
expected_output
=
tf
.
image
.
resize
(
module
.
model
(
processed_images
,
training
=
False
),
[
112
,
112
],
method
=
'bilinear'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment