Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
42499eb5
Commit
42499eb5
authored
May 12, 2021
by
A. Unique TensorFlower
Browse files
Merge pull request #9971 from anshkumar:master
PiperOrigin-RevId: 373398932
parents
7417c721
a2f397a8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
261 additions
and
0 deletions
+261
-0
official/vision/beta/projects/deepmac_maskrcnn/serving/detection.py
...ision/beta/projects/deepmac_maskrcnn/serving/detection.py
+43
-0
official/vision/beta/projects/deepmac_maskrcnn/serving/detection_test.py
.../beta/projects/deepmac_maskrcnn/serving/detection_test.py
+113
-0
official/vision/beta/projects/deepmac_maskrcnn/serving/export_saved_model.py
...a/projects/deepmac_maskrcnn/serving/export_saved_model.py
+105
-0
No files found.
official/vision/beta/projects/deepmac_maskrcnn/serving/detection.py
0 → 100644
View file @
42499eb5
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection input and model functions for serving/inference."""
import
tensorflow
as
tf
from
official.vision.beta.projects.deepmac_maskrcnn.configs
import
deep_mask_head_rcnn
as
cfg
from
official.vision.beta.projects.deepmac_maskrcnn.tasks
import
deep_mask_head_rcnn
from
official.vision.beta.serving
import
detection
class
DetectionModule
(
detection
.
DetectionModule
):
"""Detection Module."""
def
_build_model
(
self
):
if
self
.
_batch_size
is
None
:
ValueError
(
"batch_size can't be None for detection models"
)
if
not
self
.
params
.
task
.
model
.
detection_generator
.
use_batched_nms
:
ValueError
(
'Only batched_nms is supported.'
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
3
])
if
isinstance
(
self
.
params
.
task
.
model
,
cfg
.
DeepMaskHeadRCNN
):
model
=
deep_mask_head_rcnn
.
build_maskrcnn
(
input_specs
=
input_specs
,
model_config
=
self
.
params
.
task
.
model
)
else
:
raise
ValueError
(
'Detection module not implemented for {} model.'
.
format
(
type
(
self
.
params
.
task
.
model
)))
return
model
official/vision/beta/projects/deepmac_maskrcnn/serving/detection_test.py
0 → 100644
View file @
42499eb5
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Test for image detection export lib."""
import
io
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.vision.beta.projects.deepmac_maskrcnn.serving
import
detection
class
DetectionExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_detection_module
(
self
,
experiment_name
):
params
=
exp_factory
.
get_exp_config
(
experiment_name
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
params
.
task
.
model
.
detection_generator
.
use_batched_nms
=
True
detection_module
=
detection
.
DetectionModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
640
,
640
])
return
detection_module
def
_export_from_module
(
self
,
module
,
input_type
,
save_directory
):
signatures
=
module
.
get_inference_signatures
(
{
input_type
:
'serving_default'
})
tf
.
saved_model
.
save
(
module
,
save_directory
,
signatures
=
signatures
)
def
_get_dummy_input
(
self
,
input_type
,
batch_size
,
image_size
):
"""Get dummy input for the given input type."""
h
,
w
=
image_size
if
input_type
==
'image_tensor'
:
return
tf
.
zeros
((
batch_size
,
h
,
w
,
3
),
dtype
=
np
.
uint8
)
elif
input_type
==
'image_bytes'
:
image
=
Image
.
fromarray
(
np
.
zeros
((
h
,
w
,
3
),
dtype
=
np
.
uint8
))
byte_io
=
io
.
BytesIO
()
image
.
save
(
byte_io
,
'PNG'
)
return
[
byte_io
.
getvalue
()
for
b
in
range
(
batch_size
)]
elif
input_type
==
'tf_example'
:
image_tensor
=
tf
.
zeros
((
h
,
w
,
3
),
dtype
=
tf
.
uint8
)
encoded_jpeg
=
tf
.
image
.
encode_jpeg
(
tf
.
constant
(
image_tensor
)).
numpy
()
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
for
b
in
range
(
batch_size
)]
@
parameterized
.
parameters
(
(
'image_tensor'
,
'deep_mask_head_rcnn_resnetfpn_coco'
,
[
640
,
640
]),
(
'image_bytes'
,
'deep_mask_head_rcnn_resnetfpn_coco'
,
[
640
,
384
]),
(
'tf_example'
,
'deep_mask_head_rcnn_resnetfpn_coco'
,
[
640
,
640
]),
)
def
test_export
(
self
,
input_type
,
experiment_name
,
image_size
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_detection_module
(
experiment_name
)
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'saved_model.pb'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.index'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.data-00000-of-00001'
)))
imported
=
tf
.
saved_model
.
load
(
tmp_dir
)
detection_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
,
batch_size
=
1
,
image_size
=
image_size
)
processed_images
,
anchor_boxes
,
image_info
=
module
.
_build_inputs
(
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
))
image_shape
=
image_info
[
1
,
:]
image_shape
=
tf
.
expand_dims
(
image_shape
,
0
)
processed_images
=
tf
.
expand_dims
(
processed_images
,
0
)
for
l
,
l_boxes
in
anchor_boxes
.
items
():
anchor_boxes
[
l
]
=
tf
.
expand_dims
(
l_boxes
,
0
)
expected_outputs
=
module
.
model
(
images
=
processed_images
,
image_shape
=
image_shape
,
anchor_boxes
=
anchor_boxes
,
training
=
False
)
outputs
=
detection_fn
(
tf
.
constant
(
images
))
self
.
assertAllClose
(
outputs
[
'num_detections'
].
numpy
(),
expected_outputs
[
'num_detections'
].
numpy
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/deepmac_maskrcnn/serving/export_saved_model.py
0 → 100644
View file @
42499eb5
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r
"""Deepmac model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
CONFIG_FILE_PATH = XX
export_saved_model --export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--config_file=${CONFIG_FILE_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from
absl
import
app
from
absl
import
flags
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.vision.beta.projects.deepmac_maskrcnn.serving
import
detection
from
official.vision.beta.projects.deepmac_maskrcnn.tasks
import
deep_mask_head_rcnn
# pylint: disable=unused-import
from
official.vision.beta.serving
import
export_saved_model_lib
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'experiment'
,
'deep_mask_head_rcnn_resnetfpn_coco'
,
'experiment type, e.g. retinanet_resnetfpn_coco'
)
flags
.
DEFINE_string
(
'export_dir'
,
None
,
'The export directory.'
)
flags
.
DEFINE_string
(
'checkpoint_path'
,
None
,
'Checkpoint path.'
)
flags
.
DEFINE_multi_string
(
'config_file'
,
default
=
None
,
help
=
'YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.'
)
flags
.
DEFINE_string
(
'params_override'
,
''
,
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
None
,
'The batch size.'
)
flags
.
DEFINE_string
(
'input_type'
,
'image_tensor'
,
'One of `image_tensor`, `image_bytes`, `tf_example`.'
)
flags
.
DEFINE_string
(
'input_image_size'
,
'224,224'
,
'The comma-separated string of two integers representing the height,width '
'of the input to the model.'
)
def
main
(
_
):
params
=
exp_factory
.
get_exp_config
(
FLAGS
.
experiment
)
for
config_file
in
FLAGS
.
config_file
or
[]:
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
if
FLAGS
.
params_override
:
params
=
hyperparams
.
override_params_dict
(
params
,
FLAGS
.
params_override
,
is_strict
=
True
)
params
.
validate
()
params
.
lock
()
export_module
=
detection
.
DetectionModule
(
params
=
params
,
batch_size
=
FLAGS
.
batch_size
,
input_image_size
=
[
int
(
x
)
for
x
in
FLAGS
.
input_image_size
.
split
(
','
)],
num_channels
=
3
)
export_saved_model_lib
.
export_inference_graph
(
input_type
=
FLAGS
.
input_type
,
batch_size
=
FLAGS
.
batch_size
,
input_image_size
=
[
int
(
x
)
for
x
in
FLAGS
.
input_image_size
.
split
(
','
)],
params
=
params
,
checkpoint_path
=
FLAGS
.
checkpoint_path
,
export_dir
=
FLAGS
.
export_dir
,
export_module
=
export_module
,
export_checkpoint_subdir
=
'checkpoint'
,
export_saved_model_subdir
=
'saved_model'
)
if
__name__
==
'__main__'
:
app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment