Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c633f2c8
Commit
c633f2c8
authored
Jun 09, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Jun 09, 2021
Browse files
Internal change
PiperOrigin-RevId: 378423112
parent
5fd25faa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
175 additions
and
77 deletions
+175
-77
official/vision/beta/projects/movinet/export_saved_model.py
official/vision/beta/projects/movinet/export_saved_model.py
+73
-77
official/vision/beta/projects/movinet/export_saved_model_test.py
...l/vision/beta/projects/movinet/export_saved_model_test.py
+102
-0
No files found.
official/vision/beta/projects/movinet/export_saved_model.py
View file @
c633f2c8
...
...
@@ -19,38 +19,18 @@ Export example:
```shell
python3 export_saved_model.py \
--
outpu
t_path=/tmp/movinet/ \
--
expor
t_path=/tmp/movinet/ \
--model_id=a0 \
--causal=True \
--conv_type="3d" \
--num_classes=600 \
--use_positional_encoding=False \
--checkpoint_path=""
```
To use an exported saved_model in various applications:
```python
import tensorflow as tf
import tensorflow_hub as hub
saved_model_path = ...
inputs = tf.keras.layers.Input(
shape=[None, None, None, 3],
dtype=tf.float32)
encoder = hub.KerasLayer(saved_model_path, trainable=True)
outputs = encoder(inputs)
model = tf.keras.Model(inputs, outputs)
example_input = tf.ones([1, 8, 172, 172, 3])
outputs = model(example_input, states)
```
To use an exported saved_model, refer to export_saved_model_test.py.
"""
from
typing
import
Sequence
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
...
...
@@ -59,8 +39,8 @@ from official.vision.beta.projects.movinet.modeling import movinet
from
official.vision.beta.projects.movinet.modeling
import
movinet_model
flags
.
DEFINE_string
(
'
outpu
t_path'
,
'/tmp/movinet/'
,
'
P
ath to save
d exported
saved_model file.'
)
'
expor
t_path'
,
'/tmp/movinet/'
,
'
Export p
ath to save
the
saved_model file.'
)
flags
.
DEFINE_string
(
'model_id'
,
'a0'
,
'MoViNet model name.'
)
flags
.
DEFINE_bool
(
...
...
@@ -73,8 +53,20 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).'
)
flags
.
DEFINE_bool
(
'use_positional_encoding'
,
False
,
'Whether to use positional encoding (only applied when causal=True).'
)
flags
.
DEFINE_integer
(
'num_classes'
,
600
,
'The number of classes for prediction.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
None
,
'The batch size of the input. Set to None for dynamic input.'
)
flags
.
DEFINE_integer
(
'num_frames'
,
None
,
'The number of frames of the input. Set to None for dynamic input.'
)
flags
.
DEFINE_integer
(
'image_size'
,
None
,
'The resolution of the input. Set to None for dynamic input.'
)
flags
.
DEFINE_string
(
'checkpoint_path'
,
''
,
'Checkpoint path to load. Leave blank for default initialization.'
)
...
...
@@ -82,75 +74,79 @@ flags.DEFINE_string(
FLAGS
=
flags
.
FLAGS
def
main
(
argv
:
Sequence
[
str
])
->
None
:
if
len
(
argv
)
>
1
:
raise
app
.
UsageError
(
'Too many command-line arguments.'
)
def
main
(
_
)
->
None
:
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
FLAGS
.
batch_size
,
FLAGS
.
num_frames
,
FLAGS
.
image_size
,
FLAGS
.
image_size
,
3
,
])
# Use dimensions of 1 except the channels to export faster,
# since we only really need the last dimension to build and get the output
# states. These dimensions will be set to `None` once the model is built.
input_shape
=
[
1
,
1
,
1
,
1
,
3
]
input_shape
=
[
1
if
s
is
None
else
s
for
s
in
input_specs
.
shape
]
backbone
=
movinet
.
Movinet
(
FLAGS
.
model_id
,
causal
=
FLAGS
.
causal
,
conv_type
=
FLAGS
.
conv_type
)
FLAGS
.
model_id
,
causal
=
FLAGS
.
causal
,
conv_type
=
FLAGS
.
conv_type
,
use_external_states
=
FLAGS
.
causal
,
input_specs
=
input_specs
,
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
,
num_classes
=
FLAGS
.
num_classes
,
output_states
=
FLAGS
.
causal
)
backbone
,
num_classes
=
FLAGS
.
num_classes
,
output_states
=
FLAGS
.
causal
,
input_specs
=
dict
(
image
=
input_specs
))
model
.
build
(
input_shape
)
# Compile model to generate some internal Keras variables.
model
.
compile
()
if
FLAGS
.
checkpoint_path
:
model
.
load_weights
(
FLAGS
.
checkpoint_path
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
)
status
=
checkpoint
.
restore
(
FLAGS
.
checkpoint_path
)
status
.
assert_existing_objects_matched
()
if
FLAGS
.
causal
:
# Call the model once to get the output states. Call again with `states`
# input to ensure that the inputs with the `states` argument is built
_
,
states
=
model
(
dict
(
image
=
tf
.
ones
(
input_shape
),
states
=
{}))
_
,
states
=
model
(
dict
(
image
=
tf
.
ones
(
input_shape
),
states
=
states
))
input_spec
=
tf
.
TensorSpec
(
shape
=
[
None
,
None
,
None
,
None
,
3
],
dtype
=
tf
.
float32
,
name
=
'inputs'
)
state_specs
=
{}
for
name
,
state
in
states
.
items
():
shape
=
state
.
shape
if
len
(
state
.
shape
)
==
5
:
shape
=
[
None
,
state
.
shape
[
1
],
None
,
None
,
state
.
shape
[
-
1
]]
new_spec
=
tf
.
TensorSpec
(
shape
=
shape
,
dtype
=
state
.
dtype
,
name
=
name
)
state_specs
[
name
]
=
new_spec
specs
=
(
input_spec
,
state_specs
)
# Define a tf.keras.Model with custom signatures to allow it to accept
# a state dict as an argument. We define it inline here because
# we first need to determine the shape of the state tensors before
# applying the `input_signature` argument to `tf.function`.
class
ExportStateModule
(
tf
.
Module
):
"""Module with state for exporting to saved_model."""
def
__init__
(
self
,
model
):
self
.
model
=
model
@
tf
.
function
(
input_signature
=
[
input_spec
])
def
__call__
(
self
,
inputs
):
return
self
.
model
(
dict
(
image
=
inputs
,
states
=
{}))
@
tf
.
function
(
input_signature
=
[
input_spec
])
def
base
(
self
,
inputs
):
return
self
.
model
(
dict
(
image
=
inputs
,
states
=
{}))
@
tf
.
function
(
input_signature
=
specs
)
def
stream
(
self
,
inputs
,
states
):
return
self
.
model
(
dict
(
image
=
inputs
,
states
=
states
))
module
=
ExportStateModule
(
model
)
tf
.
saved_model
.
save
(
module
,
FLAGS
.
output_path
)
# with the full output state shapes.
input_image
=
tf
.
ones
(
input_shape
)
_
,
states
=
model
({
**
model
.
init_states
(
input_shape
),
'image'
:
input_image
})
_
,
states
=
model
({
**
states
,
'image'
:
input_image
})
# Create a function to explicitly set the names of the outputs
def
predict
(
inputs
):
outputs
,
states
=
model
(
inputs
)
return
{
**
states
,
'logits'
:
outputs
}
specs
=
{
name
:
tf
.
TensorSpec
(
spec
.
shape
,
name
=
name
,
dtype
=
spec
.
dtype
)
for
name
,
spec
in
model
.
initial_state_specs
(
input_specs
.
shape
).
items
()
}
specs
[
'image'
]
=
tf
.
TensorSpec
(
input_specs
.
shape
,
dtype
=
model
.
dtype
,
name
=
'image'
)
predict_fn
=
tf
.
function
(
predict
,
jit_compile
=
True
)
predict_fn
=
predict_fn
.
get_concrete_function
(
specs
)
init_states_fn
=
tf
.
function
(
model
.
init_states
,
jit_compile
=
True
)
init_states_fn
=
init_states_fn
.
get_concrete_function
(
tf
.
TensorSpec
([
5
],
dtype
=
tf
.
int32
))
signatures
=
{
'call'
:
predict_fn
,
'init_states'
:
init_states_fn
}
tf
.
keras
.
models
.
save_model
(
model
,
FLAGS
.
export_path
,
signatures
=
signatures
)
else
:
_
=
model
(
tf
.
ones
(
input_shape
))
tf
.
keras
.
models
.
save_model
(
model
,
FLAGS
.
outpu
t_path
)
tf
.
keras
.
models
.
save_model
(
model
,
FLAGS
.
expor
t_path
)
print
(
' ----- Done. Saved Model is saved at {}'
.
format
(
FLAGS
.
outpu
t_path
))
print
(
' ----- Done. Saved Model is saved at {}'
.
format
(
FLAGS
.
expor
t_path
))
if
__name__
==
'__main__'
:
...
...
official/vision/beta/projects/movinet/export_saved_model_test.py
0 → 100644
View file @
c633f2c8
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_saved_model."""
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow_hub
as
hub
from
official.vision.beta.projects.movinet
import
export_saved_model
FLAGS
=
flags
.
FLAGS
class
ExportSavedModelTest
(
tf
.
test
.
TestCase
):
def
test_movinet_export_a0_base_with_tfhub
(
self
):
saved_model_path
=
self
.
get_temp_dir
()
FLAGS
.
export_path
=
saved_model_path
FLAGS
.
model_id
=
'a0'
FLAGS
.
causal
=
False
FLAGS
.
num_classes
=
600
export_saved_model
.
main
(
'unused_args'
)
encoder
=
hub
.
KerasLayer
(
saved_model_path
,
trainable
=
True
)
inputs
=
tf
.
keras
.
layers
.
Input
(
shape
=
[
None
,
None
,
None
,
3
],
dtype
=
tf
.
float32
)
outputs
=
encoder
(
dict
(
image
=
inputs
))
model
=
tf
.
keras
.
Model
(
inputs
,
outputs
)
example_input
=
tf
.
ones
([
1
,
8
,
172
,
172
,
3
])
outputs
=
model
(
example_input
)
self
.
assertEqual
(
outputs
.
shape
,
[
1
,
600
])
def
test_movinet_export_a0_stream_with_tfhub
(
self
):
saved_model_path
=
self
.
get_temp_dir
()
FLAGS
.
export_path
=
saved_model_path
FLAGS
.
model_id
=
'a0'
FLAGS
.
causal
=
True
FLAGS
.
num_classes
=
600
export_saved_model
.
main
(
'unused_args'
)
encoder
=
hub
.
KerasLayer
(
saved_model_path
,
trainable
=
True
)
image_input
=
tf
.
keras
.
layers
.
Input
(
shape
=
[
None
,
None
,
None
,
3
],
dtype
=
tf
.
float32
,
name
=
'image'
)
init_states_fn
=
encoder
.
resolved_object
.
signatures
[
'init_states'
]
state_shapes
=
{
name
:
([
s
if
s
>
0
else
None
for
s
in
state
.
shape
],
state
.
dtype
)
for
name
,
state
in
init_states_fn
(
tf
.
constant
([
0
,
0
,
0
,
0
,
3
])).
items
()
}
states_input
=
{
name
:
tf
.
keras
.
Input
(
shape
[
1
:],
dtype
=
dtype
,
name
=
name
)
for
name
,
(
shape
,
dtype
)
in
state_shapes
.
items
()
}
inputs
=
{
**
states_input
,
'image'
:
image_input
}
outputs
=
encoder
(
inputs
)
model
=
tf
.
keras
.
Model
(
inputs
,
outputs
)
example_input
=
tf
.
ones
([
1
,
8
,
172
,
172
,
3
])
frames
=
tf
.
split
(
example_input
,
example_input
.
shape
[
1
],
axis
=
1
)
init_states
=
init_states_fn
(
tf
.
shape
(
example_input
))
expected_outputs
,
_
=
model
({
**
init_states
,
'image'
:
example_input
})
states
=
init_states
for
frame
in
frames
:
outputs
,
states
=
model
({
**
states
,
'image'
:
frame
})
self
.
assertEqual
(
outputs
.
shape
,
[
1
,
600
])
self
.
assertNotEmpty
(
states
)
self
.
assertAllClose
(
outputs
,
expected_outputs
,
1e-5
,
1e-5
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment