Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f46d7b9d
Commit
f46d7b9d
authored
Aug 19, 2022
by
Fan Yang
Committed by
A. Unique TensorFlower
Aug 19, 2022
Browse files
Internal change
PiperOrigin-RevId: 468798864
parent
ef4f89e3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
43 deletions
+59
-43
official/vision/serving/export_tflite_lib_test.py
official/vision/serving/export_tflite_lib_test.py
+59
-43
No files found.
official/vision/serving/export_tflite_lib_test.py
View file @
f46d7b9d
...
@@ -30,6 +30,39 @@ from official.vision.serving import semantic_segmentation as semantic_segmentati
...
@@ -30,6 +30,39 @@ from official.vision.serving import semantic_segmentation as semantic_segmentati
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
# Create test data for image classification.
self
.
test_tfrecord_file_cls
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'cls_test.tfrecord'
)
example
=
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
224
,
image_width
=
224
))
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_cls
,
example
=
example
,
num_samples
=
10
)
# Create test data for object detection.
self
.
test_tfrecord_file_det
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'det_test.tfrecord'
)
example
=
tfexample_utils
.
create_detection_test_example
(
image_height
=
128
,
image_width
=
128
,
image_channel
=
3
,
num_instances
=
10
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_det
,
example
=
example
,
num_samples
=
10
)
# Create test data for semantic segmentation.
self
.
test_tfrecord_file_seg
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'seg_test.tfrecord'
)
example
=
tfexample_utils
.
create_segmentation_test_example
(
image_height
=
512
,
image_width
=
512
,
image_channel
=
3
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_seg
,
example
=
example
,
num_samples
=
10
)
def
_create_test_tfrecord
(
self
,
tfrecord_file
,
example
,
num_samples
):
def
_create_test_tfrecord
(
self
,
tfrecord_file
,
example
,
num_samples
):
examples
=
[
example
]
*
num_samples
examples
=
[
example
]
*
num_samples
tfexample_utils
.
dump_to_tfrecord
(
tfexample_utils
.
dump_to_tfrecord
(
...
@@ -43,24 +76,18 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -43,24 +76,18 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
experiment
=
[
'mobilenet_imagenet'
],
experiment
=
[
'mobilenet_imagenet'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
]))
input_image_size
=
[[
224
,
224
]]))
def
test_export_tflite_image_classification
(
self
,
experiment
,
quant_type
):
def
test_export_tflite_image_classification
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'cls_test.tfrecord'
)
example
=
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
]))
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_cls
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_cls
params
.
task
.
train_data
.
shuffle_buffer_size
=
10
temp_dir
=
self
.
get_temp_dir
()
temp_dir
=
self
.
get_temp_dir
()
module
=
image_classification_serving
.
ClassificationModule
(
module
=
image_classification_serving
.
ClassificationModule
(
params
=
params
,
params
=
params
,
batch_size
=
1
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_image_size
=
[
224
,
224
]
,
input_type
=
'tflite'
)
input_type
=
'tflite'
)
self
.
_export_from_module
(
self
.
_export_from_module
(
module
=
module
,
module
=
module
,
...
@@ -78,26 +105,22 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -78,26 +105,22 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
experiment
=
[
'retinanet_mobile_coco'
],
experiment
=
[
'retinanet_mobile_coco'
],
quant_type
=
[
None
,
'default'
,
'fp16'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
]))
input_image_size
=
[[
384
,
384
]]))
def
test_export_tflite_detection
(
self
,
experiment
,
quant_type
):
def
test_export_tflite_detection
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'det_test.tfrecord'
)
example
=
tfexample_utils
.
create_detection_test_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
],
image_channel
=
3
,
num_instances
=
10
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_det
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_det
params
.
task
.
model
.
num_classes
=
2
params
.
task
.
model
.
backbone
.
spinenet_mobile
.
model_id
=
'49XS'
params
.
task
.
model
.
input_size
=
[
128
,
128
,
3
]
params
.
task
.
model
.
detection_generator
.
nms_version
=
'v1'
params
.
task
.
train_data
.
shuffle_buffer_size
=
5
temp_dir
=
self
.
get_temp_dir
()
temp_dir
=
self
.
get_temp_dir
()
module
=
detection_serving
.
DetectionModule
(
module
=
detection_serving
.
DetectionModule
(
params
=
params
,
params
=
params
,
batch_size
=
1
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_image_size
=
[
128
,
128
]
,
input_type
=
'tflite'
)
input_type
=
'tflite'
)
self
.
_export_from_module
(
self
.
_export_from_module
(
module
=
module
,
module
=
module
,
...
@@ -108,32 +131,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -108,32 +131,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
quant_type
=
quant_type
,
params
=
params
,
params
=
params
,
calibration_steps
=
5
)
calibration_steps
=
1
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
experiment
=
[
'mnv2_deeplabv3_pascal'
],
experiment
=
[
'mnv2_deeplabv3_pascal'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
,
'int8_full'
]))
input_image_size
=
[[
512
,
512
]]))
def
test_export_tflite_semantic_segmentation
(
self
,
experiment
,
quant_type
):
def
test_export_tflite_semantic_segmentation
(
self
,
experiment
,
quant_type
,
input_image_size
):
test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'seg_test.tfrecord'
)
example
=
tfexample_utils
.
create_segmentation_test_example
(
image_height
=
input_image_size
[
0
],
image_width
=
input_image_size
[
1
],
image_channel
=
3
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
test_tfrecord_file
,
example
=
example
,
num_samples
=
10
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
test_tfrecord_file
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_seg
params
.
task
.
train_data
.
input_path
=
test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_seg
params
.
task
.
train_data
.
shuffle_buffer_size
=
10
temp_dir
=
self
.
get_temp_dir
()
temp_dir
=
self
.
get_temp_dir
()
module
=
semantic_segmentation_serving
.
SegmentationModule
(
module
=
semantic_segmentation_serving
.
SegmentationModule
(
params
=
params
,
params
=
params
,
batch_size
=
1
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
input_image_size
=
[
512
,
512
]
,
input_type
=
'tflite'
)
input_type
=
'tflite'
)
self
.
_export_from_module
(
self
.
_export_from_module
(
module
=
module
,
module
=
module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment