Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
073170ba
Commit
073170ba
authored
Oct 21, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Oct 21, 2021
Browse files
Internal change
PiperOrigin-RevId: 404833082
parent
614d3d93
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
607 additions
and
0 deletions
+607
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+18
-0
official/vision/beta/dataloaders/parser.py
official/vision/beta/dataloaders/parser.py
+12
-0
official/vision/beta/serving/export_base_v2.py
official/vision/beta/serving/export_base_v2.py
+73
-0
official/vision/beta/serving/export_base_v2_test.py
official/vision/beta/serving/export_base_v2_test.py
+89
-0
official/vision/beta/serving/export_module_factory.py
official/vision/beta/serving/export_module_factory.py
+86
-0
official/vision/beta/serving/export_module_factory_test.py
official/vision/beta/serving/export_module_factory_test.py
+115
-0
official/vision/beta/serving/export_saved_model_lib_v2.py
official/vision/beta/serving/export_saved_model_lib_v2.py
+93
-0
official/vision/beta/serving/export_utils.py
official/vision/beta/serving/export_utils.py
+121
-0
No files found.
official/vision/beta/dataloaders/classification_input.py
View file @
073170ba
...
@@ -253,3 +253,21 @@ class Parser(parser.Parser):
...
@@ -253,3 +253,21 @@ class Parser(parser.Parser):
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
return
image
return
image
@
classmethod
def
inference_fn
(
cls
,
image
:
tf
.
Tensor
,
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
)
->
tf
.
Tensor
:
"""Builds image model inputs for serving."""
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
float32
)
image
=
preprocess_ops
.
center_crop_image
(
image
)
image
=
tf
.
image
.
resize
(
image
,
input_image_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
image
.
set_shape
(
input_image_size
+
[
num_channels
])
return
image
official/vision/beta/dataloaders/parser.py
View file @
073170ba
...
@@ -67,3 +67,15 @@ class Parser(object):
...
@@ -67,3 +67,15 @@ class Parser(object):
return
self
.
_parse_eval_data
(
decoded_tensors
)
return
self
.
_parse_eval_data
(
decoded_tensors
)
return
parse
return
parse
@
classmethod
def
inference_fn
(
cls
,
inputs
):
"""Parses inputs for predictions.
Args:
inputs: A Tensor, or dictionary of Tensors.
Returns:
processed_inputs: An input tensor to the model.
"""
pass
official/vision/beta/serving/export_base_v2.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
from
typing
import
Dict
,
Optional
,
Text
,
Callable
,
Any
,
Union
import
tensorflow
as
tf
from
official.core
import
export_base
class
ExportModule
(
export_base
.
ExportModule
):
"""Base Export Module."""
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
input_signature
:
Union
[
tf
.
TensorSpec
,
Dict
[
str
,
tf
.
TensorSpec
]],
preprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
,
inference_step
:
Optional
[
Callable
[...,
Any
]]
=
None
,
postprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable function to preprocess the inputs.
inference_step: An optional callable function to forward-pass the model.
postprocessor: An optional callable function to postprocess the model
outputs.
"""
super
().
__init__
(
params
,
model
=
model
,
inference_step
=
inference_step
)
self
.
preprocessor
=
preprocessor
self
.
postprocessor
=
postprocessor
self
.
input_signature
=
input_signature
@
tf
.
function
def
serve
(
self
,
inputs
):
x
=
self
.
preprocessor
(
inputs
=
inputs
)
if
self
.
preprocessor
else
inputs
x
=
self
.
inference_step
(
x
)
x
=
self
.
postprocessor
(
x
)
if
self
.
postprocessor
else
x
return
x
def
get_inference_signatures
(
self
,
function_keys
:
Dict
[
Text
,
Text
]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures
=
{}
for
_
,
def_name
in
function_keys
.
items
():
signatures
[
def_name
]
=
self
.
serve
.
get_concrete_function
(
self
.
input_signature
)
return
signatures
official/vision/beta/serving/export_base_v2_test.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base_v2."""
import
os
import
tensorflow
as
tf
from
official.core
import
export_base
from
official.vision.beta.serving
import
export_base_v2
class
TestModel
(
tf
.
keras
.
Model
):
def
__init__
(
self
):
super
().
__init__
()
self
.
_dense
=
tf
.
keras
.
layers
.
Dense
(
2
)
def
call
(
self
,
inputs
):
return
{
'outputs'
:
self
.
_dense
(
inputs
)}
class
ExportBaseTest
(
tf
.
test
.
TestCase
):
def
test_preprocessor
(
self
):
tmp_dir
=
self
.
get_temp_dir
()
model
=
TestModel
()
inputs
=
tf
.
ones
([
2
,
4
],
tf
.
float32
)
preprocess_fn
=
lambda
inputs
:
2
*
inputs
module
=
export_base_v2
.
ExportModule
(
params
=
None
,
input_signature
=
tf
.
TensorSpec
(
shape
=
[
2
,
4
]),
model
=
model
,
preprocessor
=
preprocess_fn
)
expected_output
=
model
(
preprocess_fn
(
inputs
))
ckpt_path
=
tf
.
train
.
Checkpoint
(
model
=
model
).
save
(
os
.
path
.
join
(
tmp_dir
,
'ckpt'
))
export_dir
=
export_base
.
export
(
module
,
[
'serving_default'
],
export_savedmodel_dir
=
tmp_dir
,
checkpoint_path
=
ckpt_path
,
timestamped
=
False
)
imported
=
tf
.
saved_model
.
load
(
export_dir
)
output
=
imported
.
signatures
[
'serving_default'
](
inputs
)
print
(
'output'
,
output
)
self
.
assertAllClose
(
output
[
'outputs'
].
numpy
(),
expected_output
[
'outputs'
].
numpy
())
def
test_postprocessor
(
self
):
tmp_dir
=
self
.
get_temp_dir
()
model
=
TestModel
()
inputs
=
tf
.
ones
([
2
,
4
],
tf
.
float32
)
postprocess_fn
=
lambda
logits
:
{
'outputs'
:
2
*
logits
[
'outputs'
]}
module
=
export_base_v2
.
ExportModule
(
params
=
None
,
model
=
model
,
input_signature
=
tf
.
TensorSpec
(
shape
=
[
2
,
4
]),
postprocessor
=
postprocess_fn
)
expected_output
=
postprocess_fn
(
model
(
inputs
))
ckpt_path
=
tf
.
train
.
Checkpoint
(
model
=
model
).
save
(
os
.
path
.
join
(
tmp_dir
,
'ckpt'
))
export_dir
=
export_base
.
export
(
module
,
[
'serving_default'
],
export_savedmodel_dir
=
tmp_dir
,
checkpoint_path
=
ckpt_path
,
timestamped
=
False
)
imported
=
tf
.
saved_model
.
load
(
export_dir
)
output
=
imported
.
signatures
[
'serving_default'
](
inputs
)
self
.
assertAllClose
(
output
[
'outputs'
].
numpy
(),
expected_output
[
'outputs'
].
numpy
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/serving/export_module_factory.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for vision export modules."""
from
typing
import
List
,
Optional
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.vision.beta
import
configs
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.serving
import
export_base_v2
as
export_base
from
official.vision.beta.serving
import
export_utils
def
create_classification_export_module
(
params
:
cfg
.
ExperimentConfig
,
input_type
:
str
,
batch_size
:
int
,
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
):
"""Creats classification export module."""
input_signature
=
export_utils
.
get_image_input_signatures
(
input_type
,
batch_size
,
input_image_size
,
num_channels
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
batch_size
]
+
input_image_size
+
[
num_channels
])
model
=
factory
.
build_classification_model
(
input_specs
=
input_specs
,
model_config
=
params
.
task
.
model
,
l2_regularizer
=
None
)
def
preprocess_fn
(
inputs
):
image_tensor
=
export_utils
.
parse_image
(
inputs
,
input_type
,
input_image_size
,
num_channels
)
def
preprocess_image_fn
(
inputs
):
return
classification_input
.
Parser
.
inference_fn
(
inputs
,
input_image_size
,
num_channels
)
images
=
tf
.
map_fn
(
preprocess_image_fn
,
elems
=
image_tensor
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
input_image_size
+
[
num_channels
],
dtype
=
tf
.
float32
))
return
images
def
postprocess_fn
(
logits
):
probs
=
tf
.
nn
.
softmax
(
logits
)
return
{
'logits'
:
logits
,
'probs'
:
probs
}
export_module
=
export_base
.
ExportModule
(
params
,
model
=
model
,
input_signature
=
input_signature
,
preprocessor
=
preprocess_fn
,
postprocessor
=
postprocess_fn
)
return
export_module
def
get_export_module
(
params
:
cfg
.
ExperimentConfig
,
input_type
:
str
,
batch_size
:
Optional
[
int
],
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
)
->
export_base
.
ExportModule
:
"""Factory for export modules."""
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
export_module
=
create_classification_export_module
(
params
,
input_type
,
batch_size
,
input_image_size
,
num_channels
)
else
:
raise
ValueError
(
'Export module not implemented for {} task.'
.
format
(
type
(
params
.
task
)))
return
export_module
official/vision/beta/serving/export_module_factory_test.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for vision modules."""
import
io
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
export_base
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.serving
import
export_module_factory
class
ImageClassificationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_classification_module
(
self
,
input_type
,
input_image_size
):
params
=
exp_factory
.
get_exp_config
(
'resnet_imagenet'
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
module
=
export_module_factory
.
create_classification_export_module
(
params
,
input_type
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
return
module
def
_get_dummy_input
(
self
,
input_type
):
"""Get dummy input for the given input type."""
if
input_type
==
'image_tensor'
:
return
tf
.
zeros
((
1
,
32
,
32
,
3
),
dtype
=
np
.
uint8
)
elif
input_type
==
'image_bytes'
:
image
=
Image
.
fromarray
(
np
.
zeros
((
32
,
32
,
3
),
dtype
=
np
.
uint8
))
byte_io
=
io
.
BytesIO
()
image
.
save
(
byte_io
,
'PNG'
)
return
[
byte_io
.
getvalue
()]
elif
input_type
==
'tf_example'
:
image_tensor
=
tf
.
zeros
((
32
,
32
,
3
),
dtype
=
tf
.
uint8
)
encoded_jpeg
=
tf
.
image
.
encode_jpeg
(
tf
.
constant
(
image_tensor
)).
numpy
()
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
]
@
parameterized
.
parameters
(
{
'input_type'
:
'image_tensor'
},
{
'input_type'
:
'image_bytes'
},
{
'input_type'
:
'tf_example'
},
)
def
test_export
(
self
,
input_type
=
'image_tensor'
):
input_image_size
=
[
32
,
32
]
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_classification_module
(
input_type
,
input_image_size
)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module
.
model
.
test_trackable
=
tf
.
keras
.
layers
.
InputLayer
(
input_shape
=
(
4
,))
ckpt_path
=
tf
.
train
.
Checkpoint
(
model
=
module
.
model
).
save
(
os
.
path
.
join
(
tmp_dir
,
'ckpt'
))
export_dir
=
export_base
.
export
(
module
,
[
input_type
],
export_savedmodel_dir
=
tmp_dir
,
checkpoint_path
=
ckpt_path
,
timestamped
=
False
)
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'saved_model.pb'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.index'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.data-00000-of-00001'
)))
imported
=
tf
.
saved_model
.
load
(
export_dir
)
classification_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
)
def
preprocess_image_fn
(
inputs
):
return
classification_input
.
Parser
.
inference_fn
(
inputs
,
input_image_size
,
num_channels
=
3
)
processed_images
=
tf
.
map_fn
(
preprocess_image_fn
,
elems
=
tf
.
zeros
([
1
]
+
input_image_size
+
[
3
],
dtype
=
tf
.
uint8
),
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
input_image_size
+
[
3
],
dtype
=
tf
.
float32
))
expected_logits
=
module
.
model
(
processed_images
,
training
=
False
)
expected_prob
=
tf
.
nn
.
softmax
(
expected_logits
)
out
=
classification_fn
(
tf
.
constant
(
images
))
# The imported model should contain any trackable attrs that the original
# model had.
self
.
assertTrue
(
hasattr
(
imported
.
model
,
'test_trackable'
))
self
.
assertAllClose
(
out
[
'logits'
].
numpy
(),
expected_logits
.
numpy
())
self
.
assertAllClose
(
out
[
'probs'
].
numpy
(),
expected_prob
.
numpy
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/serving/export_saved_model_lib_v2.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r
"""Vision models export utility function for serving/inference."""
import
os
from
typing
import
Optional
,
List
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
export_base
from
official.core
import
train_utils
from
official.vision.beta.serving
import
export_module_factory
def
export
(
input_type
:
str
,
batch_size
:
Optional
[
int
],
input_image_size
:
List
[
int
],
params
:
cfg
.
ExperimentConfig
,
checkpoint_path
:
str
,
export_dir
:
str
,
num_channels
:
Optional
[
int
]
=
3
,
export_module
:
Optional
[
export_base
.
ExportModule
]
=
None
,
export_checkpoint_subdir
:
Optional
[
str
]
=
None
,
export_saved_model_subdir
:
Optional
[
str
]
=
None
,
save_options
:
Optional
[
tf
.
saved_model
.
SaveOptions
]
=
None
):
"""Exports the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
batch_size: 'int', or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
"""
if
export_checkpoint_subdir
:
output_checkpoint_directory
=
os
.
path
.
join
(
export_dir
,
export_checkpoint_subdir
)
else
:
output_checkpoint_directory
=
None
if
export_saved_model_subdir
:
output_saved_model_directory
=
os
.
path
.
join
(
export_dir
,
export_saved_model_subdir
)
else
:
output_saved_model_directory
=
export_dir
export_module
=
export_module_factory
.
get_export_module
(
params
,
input_type
=
input_type
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
,
num_channels
=
num_channels
)
export_base
.
export
(
export_module
,
function_keys
=
[
input_type
],
export_savedmodel_dir
=
output_saved_model_directory
,
checkpoint_path
=
checkpoint_path
,
timestamped
=
False
,
save_options
=
save_options
)
if
output_checkpoint_directory
:
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
export_module
.
model
)
ckpt
.
save
(
os
.
path
.
join
(
output_checkpoint_directory
,
'ckpt'
))
train_utils
.
serialize_config
(
params
,
export_dir
)
official/vision/beta/serving/export_utils.py
0 → 100644
View file @
073170ba
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper utils for export library."""
from
typing
import
List
,
Optional
import
tensorflow
as
tf
# pylint: disable=g-long-lambda
def
get_image_input_signatures
(
input_type
:
str
,
batch_size
:
Optional
[
int
],
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
):
"""Gets input signatures for an image.
Args:
input_type: A `str`, can be either tf_example, image_bytes, or image_tensor.
batch_size: `int` for batch size or None.
input_image_size: List[int] for the height and width of the input image.
num_channels: `int` for number of channels in the input image.
Returns:
tf.TensorSpec of the input tensor.
"""
if
input_type
==
'image_tensor'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
]
+
[
None
]
*
len
(
input_image_size
)
+
[
num_channels
],
dtype
=
tf
.
uint8
)
elif
input_type
in
[
'image_bytes'
,
'serve_examples'
,
'tf_example'
]:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
],
dtype
=
tf
.
string
)
elif
input_type
==
'tflite'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
1
]
+
input_image_size
+
[
num_channels
],
dtype
=
tf
.
float32
)
else
:
raise
ValueError
(
'Unrecognized `input_type`'
)
return
input_signature
def
decode_image
(
encoded_image_bytes
:
str
,
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
,)
->
tf
.
Tensor
:
"""Decodes an image bytes to an image tensor.
Use `tf.image.decode_image` to decode an image if input is expected to be 2D
image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
and reshape it to desire shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
input_image_size: List[int] for the desired input size. This will be used to
infer whether the image is 2d or 3d.
num_channels: `int` for number of image channels.
Returns:
A decoded image tensor.
"""
if
len
(
input_image_size
)
==
2
:
# Decode an image if 2D input is expected.
image_tensor
=
tf
.
image
.
decode_image
(
encoded_image_bytes
,
channels
=
num_channels
)
else
:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor
=
tf
.
io
.
decode_raw
(
encoded_image_bytes
,
out_type
=
tf
.
uint8
)
image_tensor
.
set_shape
([
None
]
*
len
(
input_image_size
)
+
[
num_channels
])
return
image_tensor
def
decode_image_tf_example
(
tf_example_string_tensor
:
tf
.
train
.
Example
,
input_image_size
:
List
[
int
],
num_channels
:
int
=
3
,
encoded_key
:
str
=
'image/encoded'
)
->
tf
.
Tensor
:
"""Decodes a TF Example to an image tensor."""
keys_to_features
=
{
encoded_key
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
}
parsed_tensors
=
tf
.
io
.
parse_single_example
(
serialized
=
tf_example_string_tensor
,
features
=
keys_to_features
)
image_tensor
=
decode_image
(
parsed_tensors
[
encoded_key
],
input_image_size
=
input_image_size
,
num_channels
=
num_channels
)
return
image_tensor
def
parse_image
(
inputs
,
input_type
:
str
,
input_image_size
:
List
[
int
],
num_channels
:
int
):
"""Parses image."""
if
input_type
in
[
'tf_example'
,
'serve_examples'
]:
decode_image_tf_example_fn
=
(
lambda
x
:
decode_image_tf_example
(
x
,
input_image_size
,
num_channels
))
image_tensor
=
tf
.
map_fn
(
decode_image_tf_example_fn
,
elems
=
inputs
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
[
None
]
*
len
(
input_image_size
)
+
[
num_channels
],
dtype
=
tf
.
uint8
),
)
elif
input_type
==
'image_bytes'
:
decode_image_fn
=
lambda
x
:
decode_image
(
x
,
input_image_size
,
num_channels
)
image_tensor
=
tf
.
map_fn
(
decode_image_fn
,
elems
=
inputs
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
[
None
]
*
len
(
input_image_size
)
+
[
num_channels
],
dtype
=
tf
.
uint8
),)
else
:
image_tensor
=
inputs
return
image_tensor
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment