Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
201d523a
Commit
201d523a
authored
Sep 23, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Sep 23, 2021
Browse files
Internal change.
PiperOrigin-RevId: 398514419
parent
36ab0686
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
3 deletions
+59
-3
official/vision/beta/serving/export_tflite_lib.py
official/vision/beta/serving/export_tflite_lib.py
+8
-2
official/vision/beta/serving/export_tflite_lib_test.py
official/vision/beta/serving/export_tflite_lib_test.py
+51
-1
No files found.
official/vision/beta/serving/export_tflite_lib.py
View file @
201d523a
...
@@ -21,7 +21,7 @@ import tensorflow as tf
...
@@ -21,7 +21,7 @@ import tensorflow as tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.vision.beta
import
configs
from
official.vision.beta
import
configs
from
official.vision.beta
.tasks
import
image_classification
as
img_cls_
task
from
official.vision.beta
import
task
s
def
create_representative_dataset
(
def
create_representative_dataset
(
...
@@ -39,7 +39,13 @@ def create_representative_dataset(
...
@@ -39,7 +39,13 @@ def create_representative_dataset(
"""
"""
if
isinstance
(
params
.
task
,
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
configs
.
image_classification
.
ImageClassificationTask
):
task
=
img_cls_task
.
ImageClassificationTask
(
params
.
task
)
task
=
tasks
.
image_classification
.
ImageClassificationTask
(
params
.
task
)
elif
isinstance
(
params
.
task
,
configs
.
retinanet
.
RetinaNetTask
):
task
=
tasks
.
retinanet
.
RetinaNetTask
(
params
.
task
)
elif
isinstance
(
params
.
task
,
configs
.
semantic_segmentation
.
SemanticSegmentationTask
):
task
=
tasks
.
semantic_segmentation
.
SemanticSegmentationTask
(
params
.
task
)
else
:
else
:
raise
ValueError
(
'Task {} not supported.'
.
format
(
type
(
params
.
task
)))
raise
ValueError
(
'Task {} not supported.'
.
format
(
type
(
params
.
task
)))
# Ensure batch size is 1 for TFLite model.
# Ensure batch size is 1 for TFLite model.
...
...
official/vision/beta/serving/export_tflite_lib_test.py
View file @
201d523a
...
@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
...
@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.serving
import
detection
as
detection_serving
from
official.vision.beta.serving
import
export_tflite_lib
from
official.vision.beta.serving
import
export_tflite_lib
from
official.vision.beta.serving
import
image_classification
as
image_classification_serving
from
official.vision.beta.serving
import
image_classification
as
image_classification_serving
from
official.vision.beta.serving
import
semantic_segmentation
as
semantic_segmentation_serving
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
experiment
=
[
'mobilenet_imagenet'
],
experiment
=
[
'mobilenet_imagenet'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
],
input_image_size
=
[[
224
,
224
]]))
input_image_size
=
[[
224
,
224
]]))
def
test_export_tflite
(
self
,
experiment
,
quant_type
,
input_image_size
):
def
test_export_tflite_image_classification
(
self
,
experiment
,
quant_type
,
input_image_size
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
_test_tfrecord_file
params
.
task
.
validation_data
.
input_path
=
self
.
_test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
_test_tfrecord_file
params
.
task
.
train_data
.
input_path
=
self
.
_test_tfrecord_file
...
@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIsInstance
(
tflite_model
,
bytes
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'retinanet_mobile_coco'
],
quant_type
=
[
None
,
'default'
,
'fp16'
],
input_image_size
=
[[
256
,
256
]]))
def
test_export_tflite_detection
(
self
,
experiment
,
quant_type
,
input_image_size
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
temp_dir
=
self
.
get_temp_dir
()
module
=
detection_serving
.
DetectionModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
5
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'seg_deeplabv3_pascal'
],
quant_type
=
[
None
,
'default'
,
'fp16'
],
input_image_size
=
[[
512
,
512
]]))
def
test_export_tflite_semantic_segmentation
(
self
,
experiment
,
quant_type
,
input_image_size
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
temp_dir
=
self
.
get_temp_dir
()
module
=
semantic_segmentation_serving
.
SegmentationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
5
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment