Commit 4d09de12 authored Aug 26, 2019 by A. Unique TensorFlower

Merge pull request #7485 from reedwm:mixed_float16_transformer

PiperOrigin-RevId: 265483790

Parents: 560b3af4, dfcca061
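
The Transformer half of this merge is an API migration: the deprecated "infer_float32_vars" setup (a per-layer dtype plus a hand-wrapped LossScaleOptimizer) is replaced by the "mixed_float16" Keras policy, which carries its own loss scale. A minimal sketch of the new setup, assuming the tf.keras.mixed_precision.experimental API as it existed around this commit (later TF releases renamed these to tf.keras.mixed_precision.Policy and set_global_policy):

```python
import tensorflow as tf

# Select float16 compute with float32 variables globally, and attach
# dynamic loss scaling to the policy instead of wrapping the optimizer.
policy = tf.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.keras.mixed_precision.experimental.set_policy(policy)

# Layers built from here on compute in float16 but keep float32 weights;
# Keras is expected to apply the policy's loss scale during training, so
# no explicit LossScaleOptimizer is needed.
```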
Changes: showing 15 changed files with 127 additions and 135 deletions.
official/transformer/v2/embedding_layer.py (+2, -13)
official/transformer/v2/transformer.py (+4, -2)
official/transformer/v2/transformer_layers_test.py (+15, -0)
official/transformer/v2/transformer_main.py (+3, -5)
official/transformer/v2/transformer_main_test.py (+1, -0)
official/transformer/v2/transformer_test.py (+1, -0)
research/lstm_object_detection/export_tflite_lstd_graph.py (+6, -10)
research/lstm_object_detection/export_tflite_lstd_graph_lib.py (+24, -22)
research/lstm_object_detection/export_tflite_lstd_model.py (+28, -31)
research/lstm_object_detection/g3doc/exporting_models.md (+13, -13)
research/lstm_object_detection/test_tflite_model.py (+18, -21)
research/lstm_object_detection/tflite/BUILD (+2, -9)
research/lstm_object_detection/tflite/WORKSPACE (+6, -0)
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc (+0, -5)
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h (+4, -4)
official/transformer/v2/embedding_layer.py
@@ -24,24 +24,14 @@ import tensorflow as tf

 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""

-  def __init__(self, vocab_size, hidden_size, dtype=None):
+  def __init__(self, vocab_size, hidden_size):
     """Specify characteristic parameters of embedding layer.

     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
-      dtype: The dtype of the layer: float16 or float32.
     """
-    if dtype == tf.float16:
-      # We cannot rely on the global policy of "infer_with_float32_vars", as
-      # this layer is called on both int64 inputs and floating-point inputs.
-      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
-      # int64, which means floating-point inputs would not be casted.
-      # TODO(b/138859351): Remove this logic once we stop using the deprecated
-      # "infer_with_float32_vars" policy
-      dtype = tf.keras.mixed_precision.experimental.Policy(
-          "float16_with_float32_vars")
-    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
+    super(EmbeddingSharedWeights, self).__init__()
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size

@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     self.shared_weights = self.add_weight(
         "weights",
         shape=[self.vocab_size, self.hidden_size],
-        dtype="float32",
         initializer=tf.random_normal_initializer(
             mean=0., stddev=self.hidden_size**-0.5))
     super(EmbeddingSharedWeights, self).build(input_shape)
...
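
The per-layer dtype plumbing deleted above becomes unnecessary because, under a global "mixed_float16" policy, add_weight creates float32 variables by default and auto-casts them to the float16 compute dtype inside call. A toy sketch of that behavior; ToyEmbedding is hypothetical and the experimental-API assumption from above applies:

```python
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy("mixed_float16")


class ToyEmbedding(tf.keras.layers.Layer):
  """Hypothetical stand-in for EmbeddingSharedWeights."""

  def build(self, input_shape):
    # No dtype="float32" needed: the policy keeps the variable in float32.
    self.table = self.add_weight(
        "table", shape=[100, 8],
        initializer=tf.random_normal_initializer(stddev=8**-0.5))
    super(ToyEmbedding, self).build(input_shape)

  def call(self, ids):
    # The variable is auto-cast to float16 here; integer ids pass through.
    return tf.gather(self.table, ids)


layer = ToyEmbedding()
out = layer(tf.constant([1, 2, 3]))
# out.dtype is float16 (the compute dtype); the weight itself is float32.
```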
official/transformer/v2/transformer.py
@@ -49,8 +49,10 @@ def create_model(params, is_train):
       label_smoothing = params["label_smoothing"]
       if params["enable_metrics_in_training"]:
         logits = metrics.MetricLayer(vocab_size)([logits, targets])
-      logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
+      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                      dtype=tf.float32)(logits)
       model = tf.keras.Model([inputs, targets], logits)
+      # TODO(reedwm): Can we do this loss in float16 instead of float32?
       loss = metrics.transformer_loss(
           logits, targets, label_smoothing, vocab_size)
       model.add_loss(loss)

@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
+        params["vocab_size"], params["hidden_size"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)
...
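
The dtype=tf.float32 on the otherwise no-op Lambda is the standard mixed-precision pattern: a layer given an explicit float32 dtype casts its float16 inputs up, so the logits feeding softmax cross-entropy stay in float32, where float16's narrow range would otherwise risk overflow or underflow. A self-contained sketch under the same experimental-API assumption:

```python
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

inputs = tf.keras.Input(shape=(16,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)  # float16 compute
logits = tf.keras.layers.Dense(10)(x)                     # float16 output
# Force the final activations back to float32 before the loss.
logits = tf.keras.layers.Lambda(
    lambda t: t, name="logits", dtype=tf.float32)(logits)
model = tf.keras.Model(inputs, logits)
print(model.output.dtype)  # float32
```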
official/transformer/v2/transformer_layers_test.py
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Tests for layers in Transformer."""

 from __future__ import absolute_import
...
@@ -79,4 +93,5 @@ class TransformerLayersTest(tf.test.TestCase):

 if __name__ == "__main__":
+  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
official/transformer/v2/transformer_main.py
@@ -168,8 +168,10 @@ class TransformerTask(object):
       # like this. What if multiple instances of TransformerTask are created?
       # We should have a better way in the tf.keras.mixed_precision API of doing
       # this.
+      loss_scale = flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16="dynamic")
       policy = tf.keras.mixed_precision.experimental.Policy(
-          "infer_float32_vars")
+          "mixed_float16", loss_scale=loss_scale)
       tf.keras.mixed_precision.experimental.set_policy(policy)

     self.distribution_strategy = distribution_utils.get_distribution_strategy(
...
@@ -417,10 +419,6 @@ class TransformerTask(object):
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
-    if params["dtype"] == tf.float16:
-      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt, loss_scale=flags_core.get_loss_scale(
-              self.flags_obj, default_for_fp16="dynamic"))
     return opt
...
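
For contrast, the optimizer branch deleted above did the same loss scaling by hand; once the loss scale rides on the "mixed_float16" policy, the wrapper is redundant. Roughly what the removed code set up (same experimental-API caveat):

```python
import tensorflow as tf

# Manual dynamic loss scaling: the loss is multiplied by a scale factor
# before gradients are computed, and gradients are unscaled before being
# applied, so small float16 gradients do not flush to zero.
opt = tf.keras.optimizers.Adam()
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale="dynamic")
```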
official/transformer/v2/transformer_main_test.py
...
@@ -184,4 +184,5 @@ class TransformerTaskTest(tf.test.TestCase):

 if __name__ == '__main__':
+  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
official/transformer/v2/transformer_test.py
...
@@ -65,4 +65,5 @@ class TransformerV2Test(tf.test.TestCase):

 if __name__ == "__main__":
+  tf.compat.v1.enable_v2_behavior()
   tf.test.main()
research/lstm_object_detection/export_tflite_lstd_graph.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 r"""Exports an LSTM detection model to use with tf-lite.

 Outputs file:
...
@@ -85,9 +86,8 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
 """

 import tensorflow as tf
-from lstm_object_detection import export_tflite_lstd_graph_lib
 from lstm_object_detection.utils import config_util
+from lstm_object_detection import export_tflite_lstd_graph_lib

 flags = tf.app.flags
 flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
...
@@ -122,16 +122,12 @@ def main(argv):
   flags.mark_flag_as_required('trained_checkpoint_prefix')
   pipeline_config = config_util.get_configs_from_pipeline_file(
       FLAGS.pipeline_config_path)
   export_tflite_lstd_graph_lib.export_tflite_graph(
-      pipeline_config,
-      FLAGS.trained_checkpoint_prefix,
-      FLAGS.output_directory,
-      FLAGS.add_postprocessing_op,
-      FLAGS.max_detections,
-      FLAGS.max_classes_per_detection,
-      use_regular_nms=FLAGS.use_regular_nms)
+      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
+      FLAGS.add_postprocessing_op, FLAGS.max_detections,
+      FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)

 if __name__ == '__main__':
...
research/lstm_object_detection/export_tflite_lstd_graph_lib.py
@@ -12,26 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 r"""Exports detection models to use with tf-lite.

 See export_tflite_lstd_graph.py for usage.
 """

 import os
 import tempfile
 import numpy as np
 import tensorflow as tf

 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.tools.graph_transforms import TransformGraph
+from lstm_object_detection import model_builder
 from object_detection import exporter
 from object_detection.builders import graph_rewriter_builder
 from object_detection.builders import post_processing_builder
 from object_detection.core import box_list
-from lstm_object_detection import model_builder

 _DEFAULT_NUM_CHANNELS = 3
 _DEFAULT_NUM_COORD_BOX = 4
...
@@ -84,11 +84,11 @@ def append_postprocessing_op(frozen_graph_def,
     num_classes: number of classes in SSD detector
     scale_values: scale values is a dict with following key-value pairs
       {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode
       centersize boxes
     detections_per_class: In regular NonMaxSuppression, number of anchors used
       for NonMaxSuppression per class
-    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
-      Fast NMS.
+    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
+      of Fast NMS.

   Returns:
     transformed_graph_def: Frozen GraphDef with postprocessing custom op
...
@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
   is written to output_dir/tflite_graph.pb.

   Args:
     pipeline_config: Dictionary of configuration objects. Keys are `model`,
       `train_config`,
       `train_input_config`, `eval_config`, `eval_input_config`,
       `lstm_model`.
       Value are the corresponding config objects.
     trained_checkpoint_prefix: a file prefix for the checkpoint containing the
       trained parameters of the SSD model.
     output_dir: A directory to write the tflite graph and anchor file to.
...
@@ -176,9 +176,9 @@ def export_tflite_graph(pipeline_config,
     max_detections: Maximum number of detections (boxes) to show
     max_classes_per_detection: Number of classes to display per detection
     detections_per_class: In regular NonMaxSuppression, number of anchors used
       for NonMaxSuppression per class
-    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
-      Fast NMS.
+    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
+      of Fast NMS.
     binary_graph_name: Name of the exported graph file in binary format.
     txt_graph_name: Name of the exported graph file in text format.
...
@@ -197,10 +197,12 @@ def export_tflite_graph(pipeline_config,
   num_classes = model_config.ssd.num_classes
   nms_score_threshold = {
       model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
   }
   nms_iou_threshold = {
       model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
   }
   scale_values = {}
   scale_values['y_scale'] = {
...
@@ -224,7 +226,7 @@ def export_tflite_graph(pipeline_config,
     width = image_resizer_config.fixed_shape_resizer.width
     if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
       num_channels = 1
+    #TODO(richardbrks) figure out how to make with a None defined batch size
     shape = [lstm_config.eval_unroll_length, height, width, num_channels]
   else:
     raise ValueError(
...
@@ -233,14 +235,14 @@ def export_tflite_graph(pipeline_config,
         image_resizer_config.WhichOneof('image_resizer_oneof')))

   video_tensor = tf.placeholder(
       tf.float32, shape=shape, name='input_video_tensor')
   detection_model = model_builder.build(
       model_config, lstm_config,
       is_training=False)
   preprocessed_video, true_image_shapes = detection_model.preprocess(
       tf.to_float(video_tensor))
   predicted_tensors = detection_model.predict(preprocessed_video,
                                               true_image_shapes)
   # predicted_tensors = detection_model.postprocess(predicted_tensors,
   #                                                 true_image_shapes)
   # The score conversion occurs before the post-processing custom op
...
@@ -309,7 +311,7 @@ def export_tflite_graph(pipeline_config,
       initializer_nodes='')

   # Add new operation to do post processing in a custom op (TF Lite only)
+  #(richardbrks) Do use this or detection_model.postprocess?
   if add_postprocessing_op:
     transformed_graph_def = append_postprocessing_op(
         frozen_graph_def, max_detections, max_classes_per_detection,
...
research/lstm_object_detection/export_tflite_lstd_model.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-"""Export a LSTD model in tflite format."""
-
 import os
 from absl import flags
 import tensorflow as tf
...
@@ -31,35 +29,34 @@ FLAGS = flags.FLAGS

 def main(_):
   flags.mark_flag_as_required('export_path')
   flags.mark_flag_as_required('frozen_graph_path')
   flags.mark_flag_as_required('pipeline_config_path')

   configs = config_util.get_configs_from_pipeline_file(
       FLAGS.pipeline_config_path)
   lstm_config = configs['lstm_model']

   input_arrays = ['input_video_tensor']
   output_arrays = [
       'TFLite_Detection_PostProcess',
       'TFLite_Detection_PostProcess:1',
       'TFLite_Detection_PostProcess:2',
       'TFLite_Detection_PostProcess:3',
   ]
   input_shapes = {
       'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
   }

   converter = tf.lite.TFLiteConverter.from_frozen_graph(
       FLAGS.frozen_graph_path,
-      input_arrays, output_arrays,
-      input_shapes=input_shapes
-  )
+      input_arrays,
+      output_arrays,
+      input_shapes=input_shapes)
   converter.allow_custom_ops = True
   tflite_model = converter.convert()
   ofilename = os.path.join(FLAGS.export_path)
-  open(ofilename, "wb").write(tflite_model)
+  open(ofilename, 'wb').write(tflite_model)

 if __name__ == '__main__':
   tf.app.run()
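
The conversion flow this script wraps is worth seeing standalone: from_frozen_graph is a TF1 API (under TF2 it survives as tf.compat.v1.lite.TFLiteConverter), and allow_custom_ops is needed because TFLite_Detection_PostProcess is a custom op the converter cannot otherwise validate. A sketch with illustrative paths and shapes:

```python
import tensorflow as tf

# Hypothetical paths/shapes; they mirror the flags the script takes.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    '/tmp/tflite_graph.pb',
    input_arrays=['input_video_tensor'],
    output_arrays=['TFLite_Detection_PostProcess'],
    input_shapes={'input_video_tensor': [4, 320, 320, 3]})
# The detection postprocess is a TFLite custom op; without this flag the
# converter would reject the graph.
converter.allow_custom_ops = True
open('/tmp/model.tflite', 'wb').write(converter.convert())
```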
research/lstm_object_detection/g3doc/exporting_models.md
 # Exporting a tflite model from a checkpoint

-Starting from a trained model checkpoint, creating a tflite model requires 2
-steps:
+Starting from a trained model checkpoint, creating a tflite model requires 2 steps:

 * exporting a tflite frozen graph from a checkpoint
 * exporting a tflite model from a frozen graph

 ## Exporting a tflite frozen graph from a checkpoint
...
@@ -20,14 +20,14 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
   --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
   --trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
   --output_directory ${EXPORT_DIR} \
   --add_preprocessing_op
 ```

-After export, you should see the directory ${EXPORT_DIR} containing the
-following files:
+After export, you should see the directory ${EXPORT_DIR} containing the following files:

 * `tflite_graph.pb`
 * `tflite_graph.pbtxt`

 ## Exporting a tflite model from a frozen graph
...
@@ -40,10 +40,10 @@ FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
 EXPORT_PATH={path to filename that will be used for export}
 PIPELINE_CONFIG_PATH={path to pipeline config}

 python lstm_object_detection/export_tflite_lstd_model.py \
   --export_path ${EXPORT_PATH} \
   --frozen_graph_path ${FROZEN_GRAPH_PATH} \
   --pipeline_config_path ${PIPELINE_CONFIG_PATH}
 ```

 After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
 model to be used by an application.
\ No newline at end of file
research/lstm_object_detection/test_tflite_model.py
@@ -13,9 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-"""Test a tflite model using random input data."""
-
-from __future__ import print_function
 from absl import flags
 import numpy as np
 import tensorflow as tf
...
@@ -26,28 +23,28 @@ FLAGS = flags.FLAGS

 def main(_):
   flags.mark_flag_as_required('model_path')

   # Load TFLite model and allocate tensors.
   interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
   interpreter.allocate_tensors()

   # Get input and output tensors.
   input_details = interpreter.get_input_details()
-  print('input_details:', input_details)
+  print 'input_details:', input_details
   output_details = interpreter.get_output_details()
-  print('output_details:', output_details)
+  print 'output_details:', output_details

   # Test model on random input data.
   input_shape = input_details[0]['shape']
   # change the following line to feed into your own data.
   input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
   interpreter.set_tensor(input_details[0]['index'], input_data)

   interpreter.invoke()
   output_data = interpreter.get_tensor(output_details[0]['index'])
-  print(output_data)
+  print output_data

 if __name__ == '__main__':
   tf.app.run()
research/lstm_object_detection/tflite/BUILD
@@ -59,19 +59,12 @@ cc_library(
     name = "mobile_lstd_tflite_client",
     srcs = ["mobile_lstd_tflite_client.cc"],
     hdrs = ["mobile_lstd_tflite_client.h"],
-    defines = select({
-        "//conditions:default": [],
-        "enable_edgetpu": ["ENABLE_EDGETPU"],
-    }),
     deps = [
         ":mobile_ssd_client",
         ":mobile_ssd_tflite_client",
-        "@com_google_glog//:glog",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_glog//:glog",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ] + select({
-        "//conditions:default": [],
-        "enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
-    }),
+    ],
     alwayslink = 1,
 )
research/lstm_object_detection/tflite/WORKSPACE
@@ -90,6 +90,12 @@ http_archive(
     sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
 )

+#
+# http_archive(
+#     name = "com_google_protobuf",
+#     strip_prefix = "protobuf-master",
+#     urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
+# )

 # Needed by TensorFlow
 http_archive(
...
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
@@ -66,11 +66,6 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
     interpreter_->UseNNAPI(false);
   }

-#ifdef ENABLE_EDGETPU
-  interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
-                                   edge_tpu_context_.get());
-#endif
-
   // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
   // raw_inputs/init_lstm_h
   if (interpreter_->inputs().size() != 3) {
...
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
@@ -76,10 +76,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   std::unique_ptr<::tflite::MutableOpResolver> resolver_;
   std::unique_ptr<::tflite::Interpreter> interpreter_;

-#ifdef ENABLE_EDGETPU
-  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
-#endif
-
  private:
   // MobileSSDTfLiteClient is neither copyable nor movable.
   MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;
...
@@ -107,6 +103,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   bool FloatInference(const uint8_t* input_data);
   bool QuantizedInference(const uint8_t* input_data);
   void GetOutputBoxesAndScoreTensorsFromUInt8();
+
+#ifdef ENABLE_EDGETPU
+  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
+#endif
 };
 }  // namespace tflite
...