ModelZoo / ResNet50_tensorflow / Commits / 8b641b13

Unverified commit 8b641b13, authored Mar 26, 2022 by Srihari Humbarwadi, committed by GitHub on Mar 26, 2022.

    Merge branch 'tensorflow:master' into panoptic-deeplab

Parents: 7cffacfe, 357fa547
Changes: 411
Showing 11 changed files with 0 additions and 1245 deletions (+0 −1245):

official/vision/beta/serving/detection_test.py               +0 −145
official/vision/beta/serving/export_base.py                  +0 −192
official/vision/beta/serving/export_base_v2.py               +0 −75
official/vision/beta/serving/export_base_v2_test.py          +0 −89
official/vision/beta/serving/export_module_factory.py        +0 −89
official/vision/beta/serving/export_module_factory_test.py   +0 −117
official/vision/beta/serving/export_saved_model.py           +0 −107
official/vision/beta/serving/export_saved_model_lib.py       +0 −164
official/vision/beta/serving/export_saved_model_lib_test.py  +0 −69
official/vision/beta/serving/export_saved_model_lib_v2.py    +0 −93
official/vision/beta/serving/export_tfhub.py                 +0 −105
Too many changes to show. To preserve performance only 411 of 411+ files are displayed.
official/vision/beta/serving/detection_test.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Test for image detection export lib."""

import io
import os

from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import detection


class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):

  def _get_detection_module(self, experiment_name, input_type):
    params = exp_factory.get_exp_config(experiment_name)
    params.task.model.backbone.resnet.model_id = 18
    params.task.model.detection_generator.nms_version = 'batched'
    detection_module = detection.DetectionModule(
        params, batch_size=1, input_image_size=[640, 640],
        input_type=input_type)
    return detection_module

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type, batch_size, image_size):
    """Get dummy input for the given input type."""
    h, w = image_size

    if input_type == 'image_tensor':
      return tf.zeros((batch_size, h, w, 3), dtype=np.uint8)
    elif input_type == 'image_bytes':
      image = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue() for b in range(batch_size)]
    elif input_type == 'tf_example':
      image_tensor = tf.zeros((h, w, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
              })).SerializeToString()
      return [example for b in range(batch_size)]
    elif input_type == 'tflite':
      return tf.zeros((batch_size, h, w, 3), dtype=np.float32)

  @parameterized.parameters(
      ('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
      ('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
      ('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
      ('tflite', 'fasterrcnn_resnetfpn_coco', [640, 640]),
      ('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
      ('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
      ('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
      ('tflite', 'maskrcnn_resnetfpn_coco', [640, 640]),
      ('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
      ('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
      ('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
      ('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
      ('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
      ('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
      ('tf_example', 'retinanet_spinenet_coco', [640, 384]),
      ('tflite', 'retinanet_spinenet_coco', [640, 640]),
  )
  def test_export(self, input_type, experiment_name, image_size):
    tmp_dir = self.get_temp_dir()
    module = self._get_detection_module(experiment_name, input_type)

    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    detection_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(
        input_type, batch_size=1, image_size=image_size)

    if input_type == 'tflite':
      processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
      anchor_boxes = module._build_anchor_boxes()
      image_info = tf.convert_to_tensor(
          [image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
    else:
      processed_images, anchor_boxes, image_info = module._build_inputs(
          tf.zeros((224, 224, 3), dtype=tf.uint8))
    image_shape = image_info[1, :]
    image_shape = tf.expand_dims(image_shape, 0)
    processed_images = tf.expand_dims(processed_images, 0)
    for l, l_boxes in anchor_boxes.items():
      anchor_boxes[l] = tf.expand_dims(l_boxes, 0)

    expected_outputs = module.model(
        images=processed_images,
        image_shape=image_shape,
        anchor_boxes=anchor_boxes,
        training=False)
    outputs = detection_fn(tf.constant(images))
    self.assertAllClose(outputs['num_detections'].numpy(),
                        expected_outputs['num_detections'].numpy())

  def test_build_model_fail_with_none_batch_size(self):
    params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
    with self.assertRaisesRegex(
        ValueError, 'batch_size cannot be None for detection models.'):
      detection.DetectionModule(
          params, batch_size=None, input_image_size=[640, 640])


if __name__ == '__main__':
  tf.test.main()

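Distilled from the test above, the export-and-serve round trip is only a few calls. The following is an illustrative sketch, not part of the diff: it reuses the `detection.DetectionModule` API and the `retinanet_resnetfpn_coco` experiment exercised by the test, and the output directory is a made-up example path.

import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import detection

# Build the export module from an experiment config, as the test does.
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
module = detection.DetectionModule(
    params, batch_size=1, input_image_size=[640, 640],
    input_type='image_tensor')

# Map the chosen input type to the default serving signature and save.
signatures = module.get_inference_signatures(
    {'image_tensor': 'serving_default'})
tf.saved_model.save(module, '/tmp/retinanet_export', signatures=signatures)

# Reload and run inference on a dummy uint8 image batch.
imported = tf.saved_model.load('/tmp/retinanet_export')
detect_fn = imported.signatures['serving_default']
outputs = detect_fn(tf.zeros((1, 640, 640, 3), dtype=tf.uint8))
print(outputs['num_detections'])
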
official/vision/beta/serving/export_base.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Base class for model export."""

import abc
from typing import Dict, List, Mapping, Optional, Text

import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import export_base


class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
  """Base Export Module."""

  def __init__(self,
               params: cfg.ExperimentConfig,
               *,
               batch_size: int,
               input_image_size: List[int],
               input_type: str = 'image_tensor',
               num_channels: int = 3,
               model: Optional[tf.keras.Model] = None):
    """Initializes a module for export.

    Args:
      params: Experiment params.
      batch_size: The batch size of the model input. Can be `int` or None.
      input_image_size: List or Tuple of size of the input image. For 2D image,
        it is [height, width].
      input_type: The input signature type.
      num_channels: The number of the image channels.
      model: A tf.keras.Model instance to be exported.
    """
    self.params = params
    self._batch_size = batch_size
    self._input_image_size = input_image_size
    self._num_channels = num_channels
    self._input_type = input_type
    if model is None:
      model = self._build_model()  # pylint: disable=assignment-from-none
    super().__init__(params=params, model=model)

  def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
    """Decodes an image bytes to an image tensor.

    Use `tf.image.decode_image` to decode an image if input is expected to be
    a 2D image; otherwise use `tf.io.decode_raw` to convert the raw bytes to a
    tensor and reshape it to the desired shape.

    Args:
      encoded_image_bytes: An encoded image string to be decoded.

    Returns:
      A decoded image tensor.
    """
    if len(self._input_image_size) == 2:
      # Decode an image if 2D input is expected.
      image_tensor = tf.image.decode_image(
          encoded_image_bytes, channels=self._num_channels)
      image_tensor.set_shape((None, None, self._num_channels))
    else:
      # Convert raw bytes into a tensor and reshape it, if not 2D input.
      image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
      image_tensor = tf.reshape(
          image_tensor, self._input_image_size + [self._num_channels])
    return image_tensor

  def _decode_tf_example(
      self, tf_example_string_tensor: tf.train.Example) -> tf.Tensor:
    """Decodes a TF Example to an image tensor.

    Args:
      tf_example_string_tensor: A tf.train.Example of encoded image and other
        information.

    Returns:
      A decoded image tensor.
    """
    keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
    parsed_tensors = tf.io.parse_single_example(
        serialized=tf_example_string_tensor, features=keys_to_features)
    image_tensor = self._decode_image(parsed_tensors['image/encoded'])
    return image_tensor

  def _build_model(self, **kwargs):
    """Returns a model built from the params."""
    return None

  @tf.function
  def inference_from_image_tensors(
      self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(inputs)

  @tf.function
  def inference_for_tflite(self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(inputs)

  @tf.function
  def inference_from_image_bytes(self, inputs: tf.Tensor):
    with tf.device('cpu:0'):
      images = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._decode_image,
              elems=inputs,
              fn_output_signature=tf.TensorSpec(
                  shape=[None] * len(self._input_image_size) +
                  [self._num_channels],
                  dtype=tf.uint8),
              parallel_iterations=32))
      images = tf.stack(images)
    return self.serve(images)

  @tf.function
  def inference_from_tf_example(
      self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    with tf.device('cpu:0'):
      images = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._decode_tf_example,
              elems=inputs,
              # Height/width of the shape of input images is unspecified (None)
              # at the time of decoding the example, but the shape will
              # be adjusted to conform to the input layer of the model,
              # by _run_inference_on_image_tensors() below.
              fn_output_signature=tf.TensorSpec(
                  shape=[None] * len(self._input_image_size) +
                  [self._num_channels],
                  dtype=tf.uint8),
              dtype=tf.uint8,
              parallel_iterations=32))
      images = tf.stack(images)
      return self.serve(images)

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary with keys as the function to create signature
        for and values as the signature keys when returns.

    Returns:
      A dictionary with key as signature key and value as concrete functions
        that can be used for tf.saved_model.save.
    """
    signatures = {}
    for key, def_name in function_keys.items():
      if key == 'image_tensor':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size] + [None] * len(self._input_image_size) +
            [self._num_channels],
            dtype=tf.uint8)
        signatures[def_name] = (
            self.inference_from_image_tensors.get_concrete_function(
                input_signature))
      elif key == 'image_bytes':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size], dtype=tf.string)
        signatures[def_name] = (
            self.inference_from_image_bytes.get_concrete_function(
                input_signature))
      elif key == 'serve_examples' or key == 'tf_example':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size], dtype=tf.string)
        signatures[def_name] = (
            self.inference_from_tf_example.get_concrete_function(
                input_signature))
      elif key == 'tflite':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size] + self._input_image_size +
            [self._num_channels],
            dtype=tf.float32)
        signatures[def_name] = self.inference_for_tflite.get_concrete_function(
            input_signature)
      else:
        raise ValueError('Unrecognized `input_type`')
    return signatures

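To make the hook points concrete: a task-specific module supplies `_build_model()` and a forward pass, and the base class above handles input decoding and signature wiring. The snippet below is an illustrative sketch, not part of the diff; the class name and toy Keras model are made up, and it assumes `serve()` is the method the parent `official.core.export_base.ExportModule` dispatches to, as the real vision serving modules (for example `detection.DetectionModule`) do.

import tensorflow as tf

from official.vision.beta.serving import export_base


class ToyClassificationModule(export_base.ExportModule):
  """Hypothetical export module wrapping a tiny classifier."""

  def _build_model(self):
    # A real module would build its model from `self.params`; this is a stub.
    return tf.keras.Sequential([
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10),
    ])

  def serve(self, images: tf.Tensor):
    # Normalize uint8 inputs and run the model in inference mode.
    images = tf.cast(images, tf.float32) / 255.0
    logits = self.model(images, training=False)
    return {'logits': logits}


# Hypothetical usage: signatures come out exactly as in the tests above.
module = ToyClassificationModule(
    params=None, batch_size=1, input_image_size=[224, 224])
signatures = module.get_inference_signatures(
    {'image_tensor': 'serving_default'})
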
official/vision/beta/serving/export_base_v2.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for model export."""

from typing import Dict, Optional, Text, Callable, Any, Union

import tensorflow as tf

from official.core import export_base


class ExportModule(export_base.ExportModule):
  """Base Export Module."""

  def __init__(self,
               params,
               model: tf.keras.Model,
               input_signature: Union[tf.TensorSpec, Dict[str,
                                                          tf.TensorSpec]],
               preprocessor: Optional[Callable[..., Any]] = None,
               inference_step: Optional[Callable[..., Any]] = None,
               postprocessor: Optional[Callable[..., Any]] = None):
    """Initializes a module for export.

    Args:
      params: A dataclass for parameters to the module.
      model: A tf.keras.Model instance to be exported.
      input_signature: tf.TensorSpec, e.g.
        tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
      preprocessor: An optional callable to preprocess the inputs.
      inference_step: An optional callable to forward-pass the model.
      postprocessor: An optional callable to postprocess the model outputs.
    """
    super().__init__(
        params,
        model=model,
        preprocessor=preprocessor,
        inference_step=inference_step,
        postprocessor=postprocessor)
    self.input_signature = input_signature

  @tf.function
  def serve(self, inputs):
    x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
    x = self.inference_step(x)
    x = self.postprocessor(x) if self.postprocessor else x
    return x

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary with keys as the function to create signature
        for and values as the signature keys when returns.

    Returns:
      A dictionary with key as signature key and value as concrete functions
        that can be used for tf.saved_model.save.
    """
    signatures = {}
    for _, def_name in function_keys.items():
      signatures[def_name] = self.serve.get_concrete_function(
          self.input_signature)
    return signatures

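For comparison with the tests that follow, the v2 module can also be saved directly with `tf.saved_model.save` once the signature map is built; `export_base.export` (used in the tests below) wraps this together with checkpoint restoration. A minimal sketch, not part of the diff: the two-unit dense model, the processors, and the output path are illustrative, and the default `inference_step` (model forward pass) is assumed, as in the tests.

import tensorflow as tf

from official.vision.beta.serving import export_base_v2

# Stand-in model and processors for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
module = export_base_v2.ExportModule(
    params=None,
    model=model,
    input_signature=tf.TensorSpec(shape=[None, 4], dtype=tf.float32),
    preprocessor=lambda inputs: inputs / 2.0,
    postprocessor=lambda logits: {'outputs': tf.nn.softmax(logits)})

# Any key works here; only the signature name on the right is used.
signatures = module.get_inference_signatures({'tensor': 'serving_default'})
tf.saved_model.save(module, '/tmp/toy_export', signatures=signatures)
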
official/vision/beta/serving/export_base_v2_test.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.core.export_base_v2."""

import os

import tensorflow as tf

from official.core import export_base
from official.vision.beta.serving import export_base_v2


class TestModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(2)

  def call(self, inputs):
    return {'outputs': self._dense(inputs)}


class ExportBaseTest(tf.test.TestCase):

  def test_preprocessor(self):
    tmp_dir = self.get_temp_dir()
    model = TestModel()
    inputs = tf.ones([2, 4], tf.float32)

    preprocess_fn = lambda inputs: 2 * inputs

    module = export_base_v2.ExportModule(
        params=None,
        input_signature=tf.TensorSpec(shape=[2, 4]),
        model=model,
        preprocessor=preprocess_fn)
    expected_output = model(preprocess_fn(inputs))

    ckpt_path = tf.train.Checkpoint(model=model).save(
        os.path.join(tmp_dir, 'ckpt'))
    export_dir = export_base.export(
        module, ['serving_default'],
        export_savedmodel_dir=tmp_dir,
        checkpoint_path=ckpt_path,
        timestamped=False)

    imported = tf.saved_model.load(export_dir)
    output = imported.signatures['serving_default'](inputs)
    print('output', output)
    self.assertAllClose(output['outputs'].numpy(),
                        expected_output['outputs'].numpy())

  def test_postprocessor(self):
    tmp_dir = self.get_temp_dir()
    model = TestModel()
    inputs = tf.ones([2, 4], tf.float32)

    postprocess_fn = lambda logits: {'outputs': 2 * logits['outputs']}

    module = export_base_v2.ExportModule(
        params=None,
        model=model,
        input_signature=tf.TensorSpec(shape=[2, 4]),
        postprocessor=postprocess_fn)
    expected_output = postprocess_fn(model(inputs))

    ckpt_path = tf.train.Checkpoint(model=model).save(
        os.path.join(tmp_dir, 'ckpt'))
    export_dir = export_base.export(
        module, ['serving_default'],
        export_savedmodel_dir=tmp_dir,
        checkpoint_path=ckpt_path,
        timestamped=False)

    imported = tf.saved_model.load(export_dir)
    output = imported.signatures['serving_default'](inputs)
    self.assertAllClose(output['outputs'].numpy(),
                        expected_output['outputs'].numpy())


if __name__ == '__main__':
  tf.test.main()

official/vision/beta/serving/export_module_factory.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory for vision export modules."""

from typing import List, Optional

import tensorflow as tf

from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
from official.vision.beta.serving import export_base_v2 as export_base
from official.vision.beta.serving import export_utils


def create_classification_export_module(params: cfg.ExperimentConfig,
                                        input_type: str,
                                        batch_size: int,
                                        input_image_size: List[int],
                                        num_channels: int = 3):
  """Creates classification export module."""
  input_signature = export_utils.get_image_input_signatures(
      input_type, batch_size, input_image_size, num_channels)
  input_specs = tf.keras.layers.InputSpec(
      shape=[batch_size] + input_image_size + [num_channels])

  model = factory.build_classification_model(
      input_specs=input_specs,
      model_config=params.task.model,
      l2_regularizer=None)

  def preprocess_fn(inputs):
    image_tensor = export_utils.parse_image(inputs, input_type,
                                            input_image_size, num_channels)
    # If input_type is `tflite`, do not apply image preprocessing.
    if input_type == 'tflite':
      return image_tensor

    def preprocess_image_fn(inputs):
      return classification_input.Parser.inference_fn(inputs,
                                                      input_image_size,
                                                      num_channels)

    images = tf.map_fn(
        preprocess_image_fn,
        elems=image_tensor,
        fn_output_signature=tf.TensorSpec(
            shape=input_image_size + [num_channels], dtype=tf.float32))

    return images

  def postprocess_fn(logits):
    probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}

  export_module = export_base.ExportModule(
      params,
      model=model,
      input_signature=input_signature,
      preprocessor=preprocess_fn,
      postprocessor=postprocess_fn)
  return export_module


def get_export_module(params: cfg.ExperimentConfig,
                      input_type: str,
                      batch_size: Optional[int],
                      input_image_size: List[int],
                      num_channels: int = 3) -> export_base.ExportModule:
  """Factory for export modules."""
  if isinstance(params.task,
                configs.image_classification.ImageClassificationTask):
    export_module = create_classification_export_module(
        params, input_type, batch_size, input_image_size, num_channels)
  else:
    raise ValueError('Export module not implemented for {} task.'.format(
        type(params.task)))
  return export_module

official/vision/beta/serving/export_module_factory_test.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for vision modules."""

import io
import os

from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.core import export_base
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.serving import export_module_factory


class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):

  def _get_classification_module(self, input_type, input_image_size):
    params = exp_factory.get_exp_config('resnet_imagenet')
    params.task.model.backbone.resnet.model_id = 18
    module = export_module_factory.create_classification_export_module(
        params, input_type, batch_size=1, input_image_size=input_image_size)
    return module

  def _get_dummy_input(self, input_type):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      return tf.zeros((1, 32, 32, 3), dtype=np.uint8)
    elif input_type == 'image_bytes':
      image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    elif input_type == 'tf_example':
      image_tensor = tf.zeros((32, 32, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
              })).SerializeToString()
      return [example]

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'image_bytes'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type='image_tensor'):
    input_image_size = [32, 32]
    tmp_dir = self.get_temp_dir()
    module = self._get_classification_module(input_type, input_image_size)
    # Test that the model restores any attrs that are trackable objects
    # (eg: tables, resource variables, keras models/layers, tf.hub modules).
    module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))

    ckpt_path = tf.train.Checkpoint(model=module.model).save(
        os.path.join(tmp_dir, 'ckpt'))
    export_dir = export_base.export(
        module, [input_type],
        export_savedmodel_dir=tmp_dir,
        checkpoint_path=ckpt_path,
        timestamped=False)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(export_dir)
    classification_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(input_type)

    def preprocess_image_fn(inputs):
      return classification_input.Parser.inference_fn(
          inputs, input_image_size, num_channels=3)

    processed_images = tf.map_fn(
        preprocess_image_fn,
        elems=tf.zeros([1] + input_image_size + [3], dtype=tf.uint8),
        fn_output_signature=tf.TensorSpec(
            shape=input_image_size + [3], dtype=tf.float32))
    expected_logits = module.model(processed_images, training=False)
    expected_prob = tf.nn.softmax(expected_logits)
    out = classification_fn(tf.constant(images))

    # The imported model should contain any trackable attrs that the original
    # model had.
    self.assertTrue(hasattr(imported.model, 'test_trackable'))
    self.assertAllClose(
        out['logits'].numpy(), expected_logits.numpy(), rtol=1e-04, atol=1e-04)
    self.assertAllClose(
        out['probs'].numpy(), expected_prob.numpy(), rtol=1e-04, atol=1e-04)


if __name__ == '__main__':
  tf.test.main()

official/vision/beta/serving/export_saved_model.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
r"""Vision models export binary for serving/inference.

To export a trained checkpoint in saved_model format (shell script):

EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
                   --export_dir=${EXPORT_DIR_PATH}/ \
                   --checkpoint_path=${CHECKPOINT_PATH} \
                   --batch_size=2 \
                   --input_image_size=224,224

To serve (python):

export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""

from absl import app
from absl import flags

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_saved_model_lib

FLAGS = flags.FLAGS

flags.DEFINE_string('experiment', None,
                    'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')
flags.DEFINE_integer('batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
                    'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
                    'The subdirectory for saved model.')


def main(_):
  params = exp_factory.get_exp_config(FLAGS.experiment)
  for config_file in FLAGS.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
      export_saved_model_subdir=FLAGS.export_saved_model_subdir)


if __name__ == '__main__':
  app.run(main)

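When a model is exported with --input_type=tflite, the resulting SavedModel takes fixed-size float32 images and can be handed to the TFLite converter. A minimal sketch, not part of the diff: the paths are illustrative, and the select-TF-ops fallback is an assumption that may or may not be needed depending on the model's post-processing ops.

import tensorflow as tf

# Illustrative path: by default the binary writes the SavedModel under
# <export_dir>/saved_model (see --export_saved_model_subdir above).
saved_model_dir = '/tmp/retinanet_export/saved_model'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Some post-processing ops may fall outside the builtin TFLite op set;
# allowing select TF ops is a conservative fallback.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

with open('/tmp/retinanet.tflite', 'wb') as f:
  f.write(tflite_model)
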
official/vision/beta/serving/export_saved_model_lib.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
r"""Vision models export utility function for serving/inference."""

import os
from typing import Optional, List

from absl import logging
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.beta import configs
from official.vision.beta.serving import detection
from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation
from official.vision.beta.serving import video_classification


def export_inference_graph(
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    params: cfg.ExperimentConfig,
    checkpoint_path: str,
    export_dir: str,
    num_channels: Optional[int] = 3,
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None,
    log_model_flops_and_params: bool = False):
  """Exports inference graph for the model specified in the exp config.

  Saved model is stored at export_dir/saved_model, checkpoint is saved
  at export_dir/checkpoint, and params is saved at export_dir/params.yaml.

  Args:
    input_type: One of `image_tensor`, `image_bytes`, `tf_example` or `tflite`.
    batch_size: 'int', or None.
    input_image_size: List or Tuple of height and width.
    params: Experiment params.
    checkpoint_path: Trained checkpoint path or directory.
    export_dir: Export directory path.
    num_channels: The number of input image channels.
    export_module: Optional export module to be used instead of using params
      to create one. If None, the params will be used to create an export
      module.
    export_checkpoint_subdir: Optional subdirectory under export_dir
      to store checkpoint.
    export_saved_model_subdir: Optional subdirectory under export_dir
      to store saved model.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
      and model parameters to model_params.txt.
  """

  if export_checkpoint_subdir:
    output_checkpoint_directory = os.path.join(
        export_dir, export_checkpoint_subdir)
  else:
    output_checkpoint_directory = None

  if export_saved_model_subdir:
    output_saved_model_directory = os.path.join(
        export_dir, export_saved_model_subdir)
  else:
    output_saved_model_directory = export_dir

  # TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
  if not export_module:
    if isinstance(params.task,
                  configs.image_classification.ImageClassificationTask):
      export_module = image_classification.ClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels)
    elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
        params.task, configs.maskrcnn.MaskRCNNTask):
      export_module = detection.DetectionModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels)
    elif isinstance(params.task,
                    configs.semantic_segmentation.SemanticSegmentationTask):
      export_module = semantic_segmentation.SegmentationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels)
    elif isinstance(params.task,
                    configs.video_classification.VideoClassificationTask):
      export_module = video_classification.VideoClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels)
    else:
      raise ValueError('Export module not implemented for {} task.'.format(
          type(params.task)))

  export_base.export(
      export_module,
      function_keys=[input_type],
      export_savedmodel_dir=output_saved_model_directory,
      checkpoint_path=checkpoint_path,
      timestamped=False,
      save_options=save_options)

  if output_checkpoint_directory:
    ckpt = tf.train.Checkpoint(model=export_module.model)
    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
  train_utils.serialize_config(params, export_dir)

  if log_model_flops_and_params:
    inputs_kwargs = None
    if isinstance(
        params.task,
        (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
      # We need to create inputs_kwargs argument to specify the input shapes
      # for subclass model that overrides model.call to take multiple inputs,
      # e.g., RetinaNet model.
      inputs_kwargs = {
          'images':
              tf.TensorSpec([1] + input_image_size + [num_channels],
                            tf.float32),
          'image_shape':
              tf.TensorSpec([1, 2], tf.float32)
      }
      dummy_inputs = {
          k: tf.ones(v.shape.as_list(), tf.float32)
          for k, v in inputs_kwargs.items()
      }
      # Must do forward pass to build the model.
      export_module.model(**dummy_inputs)
    else:
      logging.info(
          'Logging model flops and params not implemented for %s task.',
          type(params.task))
      return
    train_utils.try_count_flops(export_module.model, inputs_kwargs,
                                os.path.join(export_dir, 'model_flops.txt'))
    train_utils.write_model_params(export_module.model,
                                   os.path.join(export_dir, 'model_params.txt'))

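The same function can be called programmatically, without going through the export_saved_model binary and its flags. A minimal sketch, not part of the diff: the experiment name comes from the tests above, while the checkpoint and export paths are illustrative placeholders that must match your own training run.

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import export_saved_model_lib

# 'resnet_imagenet' is an ImageClassificationTask experiment, so the factory
# above selects image_classification.ClassificationModule.
params = exp_factory.get_exp_config('resnet_imagenet')

export_saved_model_lib.export_inference_graph(
    input_type='image_bytes',
    batch_size=1,
    input_image_size=[224, 224],
    params=params,
    checkpoint_path='/tmp/resnet_ckpt/ckpt-100',   # illustrative
    export_dir='/tmp/resnet_export',               # illustrative
    export_saved_model_subdir='saved_model',
    log_model_flops_and_params=False)
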
official/vision/beta/serving/export_saved_model_lib_test.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.core.export_saved_model_lib."""

import os
from unittest import mock

import tensorflow as tf

from official.core import export_base
from official.vision.beta import configs
from official.vision.beta.serving import export_saved_model_lib


class WriteModelFlopsAndParamsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.tempdir = self.create_tempdir()
    self.enter_context(
        mock.patch.object(export_base, 'export', autospec=True, spec_set=True))

  def _export_model_with_log_model_flops_and_params(self, params):
    export_saved_model_lib.export_inference_graph(
        input_type='image_tensor',
        batch_size=1,
        input_image_size=[64, 64],
        params=params,
        checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
        export_dir=self.tempdir,
        log_model_flops_and_params=True)

  def assertModelAnalysisFilesExist(self):
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))

  def test_retinanet_task(self):
    params = configs.retinanet.retinanet_resnetfpn_coco()
    params.task.model.backbone.resnet.model_id = 18
    params.task.model.num_classes = 2
    params.task.model.max_level = 6
    self._export_model_with_log_model_flops_and_params(params)
    self.assertModelAnalysisFilesExist()

  def test_maskrcnn_task(self):
    params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
    params.task.model.backbone.resnet.model_id = 18
    params.task.model.num_classes = 2
    params.task.model.max_level = 6
    self._export_model_with_log_model_flops_and_params(params)
    self.assertModelAnalysisFilesExist()


if __name__ == '__main__':
  tf.test.main()

official/vision/beta/serving/export_saved_model_lib_v2.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Vision models export utility function for serving/inference."""

import os
from typing import Optional, List

import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.beta.serving import export_module_factory


def export(
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    params: cfg.ExperimentConfig,
    checkpoint_path: str,
    export_dir: str,
    num_channels: Optional[int] = 3,
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None):
  """Exports the model specified in the exp config.

  Saved model is stored at export_dir/saved_model, checkpoint is saved
  at export_dir/checkpoint, and params is saved at export_dir/params.yaml.

  Args:
    input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
    batch_size: 'int', or None.
    input_image_size: List or Tuple of height and width.
    params: Experiment params.
    checkpoint_path: Trained checkpoint path or directory.
    export_dir: Export directory path.
    num_channels: The number of input image channels.
    export_module: Optional export module to be used instead of using params
      to create one. If None, the params will be used to create an export
      module.
    export_checkpoint_subdir: Optional subdirectory under export_dir
      to store checkpoint.
    export_saved_model_subdir: Optional subdirectory under export_dir
      to store saved model.
    save_options: `SaveOptions` for `tf.saved_model.save`.
  """

  if export_checkpoint_subdir:
    output_checkpoint_directory = os.path.join(
        export_dir, export_checkpoint_subdir)
  else:
    output_checkpoint_directory = None

  if export_saved_model_subdir:
    output_saved_model_directory = os.path.join(
        export_dir, export_saved_model_subdir)
  else:
    output_saved_model_directory = export_dir

  export_module = export_module_factory.get_export_module(
      params,
      input_type=input_type,
      batch_size=batch_size,
      input_image_size=input_image_size,
      num_channels=num_channels)

  export_base.export(
      export_module,
      function_keys=[input_type],
      export_savedmodel_dir=output_saved_model_directory,
      checkpoint_path=checkpoint_path,
      timestamped=False,
      save_options=save_options)

  if output_checkpoint_directory:
    ckpt = tf.train.Checkpoint(model=export_module.model)
    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
  train_utils.serialize_config(params, export_dir)

official/vision/beta/serving/export_tfhub.py  (deleted, 100644 → 0, view file @ 7cffacfe)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""A script to export the image classification as a TF-Hub SavedModel."""

# Import libraries
from absl import app
from absl import flags

import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.modeling import factory

FLAGS = flags.FLAGS

flags.DEFINE_string('experiment', None,
                    'experiment type, e.g. resnet_imagenet')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_string('export_path', None, 'The export directory.')
flags.DEFINE_multi_string(
    'config_file', None,
    'A YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')
flags.DEFINE_integer('batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
flags.DEFINE_boolean(
    'skip_logits_layer', False,
    'Whether to skip the prediction layer and only output the feature vector.')


def export_model_to_tfhub(params, batch_size, input_image_size,
                          skip_logits_layer, checkpoint_path, export_path):
  """Export an image classification model to TF-Hub."""
  input_specs = tf.keras.layers.InputSpec(
      shape=[batch_size] + input_image_size + [3])
  model = factory.build_classification_model(
      input_specs=input_specs,
      model_config=params.task.model,
      l2_regularizer=None,
      skip_logits_layer=skip_logits_layer)
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
  model.save(export_path, include_optimizer=False, save_format='tf')


def main(_):
  params = exp_factory.get_exp_config(FLAGS.experiment)
  for config_file in FLAGS.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  export_model_to_tfhub(
      params=params,
      batch_size=FLAGS.batch_size,
      input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
      skip_logits_layer=FLAGS.skip_logits_layer,
      checkpoint_path=FLAGS.checkpoint_path,
      export_path=FLAGS.export_path)


if __name__ == '__main__':
  app.run(main)

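The SavedModel written by export_model_to_tfhub can be consumed like any other TF-Hub image model. A minimal sketch, not part of the diff: it assumes the tensorflow_hub package is installed, that the model was exported with --skip_logits_layer so it emits feature vectors, and the export path and head size are illustrative.

import tensorflow as tf
import tensorflow_hub as hub

# Illustrative path; matches the --export_path passed to export_tfhub.
encoder = hub.KerasLayer('/tmp/resnet_tfhub', trainable=False)

# Attach a small classification head on top of the exported features.
images = tf.keras.Input(shape=(224, 224, 3), dtype=tf.float32)
features = encoder(images)
outputs = tf.keras.layers.Dense(5, activation='softmax')(features)
model = tf.keras.Model(images, outputs)
model.summary()
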