Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4d09de12
Commit
4d09de12
authored
Aug 26, 2019
by
A. Unique TensorFlower
Browse files
Merge pull request #7485 from reedwm:mixed_float16_transformer
PiperOrigin-RevId: 265483790
parents
560b3af4
dfcca061
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
127 additions
and
135 deletions
+127
-135
official/transformer/v2/embedding_layer.py
official/transformer/v2/embedding_layer.py
+2
-13
official/transformer/v2/transformer.py
official/transformer/v2/transformer.py
+4
-2
official/transformer/v2/transformer_layers_test.py
official/transformer/v2/transformer_layers_test.py
+15
-0
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+3
-5
official/transformer/v2/transformer_main_test.py
official/transformer/v2/transformer_main_test.py
+1
-0
official/transformer/v2/transformer_test.py
official/transformer/v2/transformer_test.py
+1
-0
research/lstm_object_detection/export_tflite_lstd_graph.py
research/lstm_object_detection/export_tflite_lstd_graph.py
+6
-10
research/lstm_object_detection/export_tflite_lstd_graph_lib.py
...rch/lstm_object_detection/export_tflite_lstd_graph_lib.py
+24
-22
research/lstm_object_detection/export_tflite_lstd_model.py
research/lstm_object_detection/export_tflite_lstd_model.py
+28
-31
research/lstm_object_detection/g3doc/exporting_models.md
research/lstm_object_detection/g3doc/exporting_models.md
+13
-13
research/lstm_object_detection/test_tflite_model.py
research/lstm_object_detection/test_tflite_model.py
+18
-21
research/lstm_object_detection/tflite/BUILD
research/lstm_object_detection/tflite/BUILD
+2
-9
research/lstm_object_detection/tflite/WORKSPACE
research/lstm_object_detection/tflite/WORKSPACE
+6
-0
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
...lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
+0
-5
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
...h/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
+4
-4
No files found.
official/transformer/v2/embedding_layer.py
View file @
4d09de12
...
...
@@ -24,24 +24,14 @@ import tensorflow as tf
class
EmbeddingSharedWeights
(
tf
.
keras
.
layers
.
Layer
):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def
__init__
(
self
,
vocab_size
,
hidden_size
,
dtype
=
None
):
def
__init__
(
self
,
vocab_size
,
hidden_size
):
"""Specify characteristic parameters of embedding layer.
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
dtype: The dtype of the layer: float16 or float32.
"""
if
dtype
==
tf
.
float16
:
# We cannot rely on the global policy of "infer_with_float32_vars", as
# this layer is called on both int64 inputs and floating-point inputs.
# If "infer_with_float32_vars" is used, the dtype will be inferred to be
# int64, which means floating-point inputs would not be cast.
# TODO(b/138859351): Remove this logic once we stop using the deprecated
# "infer_with_float32_vars" policy
dtype
=
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"float16_with_float32_vars"
)
super
(
EmbeddingSharedWeights
,
self
).
__init__
(
dtype
=
dtype
)
super
(
EmbeddingSharedWeights
,
self
).
__init__
()
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
...
...
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self
.
shared_weights
=
self
.
add_weight
(
"weights"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
dtype
=
"float32"
,
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
super
(
EmbeddingSharedWeights
,
self
).
build
(
input_shape
)
...
...
official/transformer/v2/transformer.py
View file @
4d09de12
...
...
@@ -49,8 +49,10 @@ def create_model(params, is_train):
label_smoothing
=
params
[
"label_smoothing"
]
if
params
[
"enable_metrics_in_training"
]:
logits
=
metrics
.
MetricLayer
(
vocab_size
)([
logits
,
targets
])
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
)(
logits
)
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
dtype
=
tf
.
float32
)(
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss
=
metrics
.
transformer_loss
(
logits
,
targets
,
label_smoothing
,
vocab_size
)
model
.
add_loss
(
loss
)
...
...
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
super
(
Transformer
,
self
).
__init__
(
name
=
name
)
self
.
params
=
params
self
.
embedding_softmax_layer
=
embedding_layer
.
EmbeddingSharedWeights
(
params
[
"vocab_size"
],
params
[
"hidden_size"
]
,
dtype
=
params
[
"dtype"
]
)
params
[
"vocab_size"
],
params
[
"hidden_size"
])
self
.
encoder_stack
=
EncoderStack
(
params
)
self
.
decoder_stack
=
DecoderStack
(
params
)
...
...
official/transformer/v2/transformer_layers_test.py
View file @
4d09de12
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layers in Transformer."""
from
__future__
import
absolute_import
...
...
@@ -79,4 +93,5 @@ class TransformerLayersTest(tf.test.TestCase):
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
enable_v2_behavior
()
tf
.
test
.
main
()
official/transformer/v2/transformer_main.py
View file @
4d09de12
...
...
@@ -168,8 +168,10 @@ class TransformerTask(object):
# like this. What if multiple instances of TransformerTask are created?
# We should have a better way in the tf.keras.mixed_precision API of doing
# this.
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
"dynamic"
)
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"
infer
_float
32_vars"
)
"
mixed
_float
16"
,
loss_scale
=
loss_scale
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
self
.
distribution_strategy
=
distribution_utils
.
get_distribution_strategy
(
...
...
@@ -417,10 +419,6 @@ class TransformerTask(object):
params
[
"optimizer_adam_beta1"
],
params
[
"optimizer_adam_beta2"
],
epsilon
=
params
[
"optimizer_adam_epsilon"
])
if
params
[
"dtype"
]
==
tf
.
float16
:
opt
=
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
opt
,
loss_scale
=
flags_core
.
get_loss_scale
(
self
.
flags_obj
,
default_for_fp16
=
"dynamic"
))
return
opt
...
...
official/transformer/v2/transformer_main_test.py
View file @
4d09de12
...
...
@@ -184,4 +184,5 @@ class TransformerTaskTest(tf.test.TestCase):
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
enable_v2_behavior
()
tf
.
test
.
main
()
official/transformer/v2/transformer_test.py
View file @
4d09de12
...
...
@@ -65,4 +65,5 @@ class TransformerV2Test(tf.test.TestCase):
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
enable_v2_behavior
()
tf
.
test
.
main
()
research/lstm_object_detection/export_tflite_lstd_graph.py
View file @
4d09de12
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Exports an LSTM detection model to use with tf-lite.
Outputs file:
...
...
@@ -85,9 +86,8 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
"""
import
tensorflow
as
tf
from
lstm_object_detection
import
export_tflite_lstd_graph_lib
from
lstm_object_detection.utils
import
config_util
from
lstm_object_detection
import
export_tflite_lstd_graph_lib
flags
=
tf
.
app
.
flags
flags
.
DEFINE_string
(
'output_directory'
,
None
,
'Path to write outputs.'
)
...
...
@@ -125,13 +125,9 @@ def main(argv):
FLAGS
.
pipeline_config_path
)
export_tflite_lstd_graph_lib
.
export_tflite_graph
(
pipeline_config
,
FLAGS
.
trained_checkpoint_prefix
,
FLAGS
.
output_directory
,
FLAGS
.
add_postprocessing_op
,
FLAGS
.
max_detections
,
FLAGS
.
max_classes_per_detection
,
use_regular_nms
=
FLAGS
.
use_regular_nms
)
pipeline_config
,
FLAGS
.
trained_checkpoint_prefix
,
FLAGS
.
output_directory
,
FLAGS
.
add_postprocessing_op
,
FLAGS
.
max_detections
,
FLAGS
.
max_classes_per_detection
,
use_regular_nms
=
FLAGS
.
use_regular_nms
)
if
__name__
==
'__main__'
:
...
...
research/lstm_object_detection/export_tflite_lstd_graph_lib.py
View file @
4d09de12
...
...
@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import
os
import
tempfile
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.core.framework
import
attr_value_pb2
from
tensorflow.core.framework
import
types_pb2
from
tensorflow.core.protobuf
import
saver_pb2
from
tensorflow.tools.graph_transforms
import
TransformGraph
from
lstm_object_detection
import
model_builder
from
object_detection
import
exporter
from
object_detection.builders
import
graph_rewriter_builder
from
object_detection.builders
import
post_processing_builder
from
object_detection.core
import
box_list
from
lstm_object_detection
import
model_builder
_DEFAULT_NUM_CHANNELS
=
3
_DEFAULT_NUM_COORD_BOX
=
4
...
...
@@ -87,8 +87,8 @@ def append_postprocessing_op(frozen_graph_def,
centersize boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of
Fast NMS.
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of
Fast NMS.
Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op
...
...
@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
is written to output_dir/tflite_graph.pb.
Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_config`,
`train_input_config`, `eval_config`, `eval_input_config`,
`lstm_model`.
Values are the corresponding config objects.
pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_config`,
`train_input_config`, `eval_config`, `eval_input_config`,
`lstm_model`.
Values are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to.
...
...
@@ -177,8 +177,8 @@ def export_tflite_graph(pipeline_config,
max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of
Fast NMS.
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of
Fast NMS.
binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format.
...
...
@@ -197,10 +197,12 @@ def export_tflite_graph(pipeline_config,
num_classes
=
model_config
.
ssd
.
num_classes
nms_score_threshold
=
{
model_config
.
ssd
.
post_processing
.
batch_non_max_suppression
.
score_threshold
model_config
.
ssd
.
post_processing
.
batch_non_max_suppression
.
score_threshold
}
nms_iou_threshold
=
{
model_config
.
ssd
.
post_processing
.
batch_non_max_suppression
.
iou_threshold
model_config
.
ssd
.
post_processing
.
batch_non_max_suppression
.
iou_threshold
}
scale_values
=
{}
scale_values
[
'y_scale'
]
=
{
...
...
@@ -224,7 +226,7 @@ def export_tflite_graph(pipeline_config,
width
=
image_resizer_config
.
fixed_shape_resizer
.
width
if
image_resizer_config
.
fixed_shape_resizer
.
convert_to_grayscale
:
num_channels
=
1
# TODO(richardbrks): figure out how to make this work with a None-defined batch size
shape
=
[
lstm_config
.
eval_unroll_length
,
height
,
width
,
num_channels
]
else
:
raise
ValueError
(
...
...
@@ -235,8 +237,8 @@ def export_tflite_graph(pipeline_config,
video_tensor
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
shape
,
name
=
'input_video_tensor'
)
detection_model
=
model_builder
.
build
(
model_config
,
lstm_config
,
is_training
=
False
)
detection_model
=
model_builder
.
build
(
model_config
,
lstm_config
,
is_training
=
False
)
preprocessed_video
,
true_image_shapes
=
detection_model
.
preprocess
(
tf
.
to_float
(
video_tensor
))
predicted_tensors
=
detection_model
.
predict
(
preprocessed_video
,
...
...
@@ -309,7 +311,7 @@ def export_tflite_graph(pipeline_config,
initializer_nodes
=
''
)
# Add new operation to do post processing in a custom op (TF Lite only)
# (richardbrks) Should we use this or detection_model.postprocess?
if
add_postprocessing_op
:
transformed_graph_def
=
append_postprocessing_op
(
frozen_graph_def
,
max_detections
,
max_classes_per_detection
,
...
...
research/lstm_object_detection/export_tflite_lstd_model.py
View file @
4d09de12
...
...
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Export a LSTD model in tflite format."""
import
os
from
absl
import
flags
import
tensorflow
as
tf
...
...
@@ -51,14 +49,13 @@ def main(_):
}
converter
=
tf
.
lite
.
TFLiteConverter
.
from_frozen_graph
(
FLAGS
.
frozen_graph_path
,
input_arrays
,
output_arrays
,
input_shapes
=
input_shapes
)
FLAGS
.
frozen_graph_path
,
input_arrays
,
output_arrays
,
input_shapes
=
input_shapes
)
converter
.
allow_custom_ops
=
True
tflite_model
=
converter
.
convert
()
ofilename
=
os
.
path
.
join
(
FLAGS
.
export_path
)
open
(
ofilename
,
'
wb
'
).
write
(
tflite_model
)
open
(
ofilename
,
"
wb
"
).
write
(
tflite_model
)
if
__name__
==
'__main__'
:
...
...
research/lstm_object_detection/g3doc/exporting_models.md
View file @
4d09de12
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires 2
steps:
Starting from a trained model checkpoint, creating a tflite model requires 2 steps:
*
exporting a tflite frozen graph from a checkpoint
*
exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
With a candidate checkpoint to export, run the following command from
...
...
@@ -23,12 +23,12 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
--add_preprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the
following files:
After export, you should see the directory ${EXPORT_DIR} containing the following files:
*
`tflite_graph.pb`
*
`tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
We then take the exported tflite-compatible tflite model, and convert it to a
...
...
research/lstm_object_detection/test_tflite_model.py
View file @
4d09de12
...
...
@@ -13,9 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Test a tflite model using random input data."""
from
__future__
import
print_function
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -34,9 +31,9 @@ def main(_):
# Get input and output tensors.
input_details
=
interpreter
.
get_input_details
()
print
(
'input_details:'
,
input_details
)
print
'input_details:'
,
input_details
output_details
=
interpreter
.
get_output_details
()
print
(
'output_details:'
,
output_details
)
print
'output_details:'
,
output_details
# Test model on random input data.
input_shape
=
input_details
[
0
][
'shape'
]
...
...
@@ -46,7 +43,7 @@ def main(_):
interpreter
.
invoke
()
output_data
=
interpreter
.
get_tensor
(
output_details
[
0
][
'index'
])
print
(
output_data
)
print
output_data
if
__name__
==
'__main__'
:
...
...
research/lstm_object_detection/tflite/BUILD
View file @
4d09de12
...
...
@@ -59,19 +59,12 @@ cc_library(
name
=
"mobile_lstd_tflite_client"
,
srcs
=
[
"mobile_lstd_tflite_client.cc"
],
hdrs
=
[
"mobile_lstd_tflite_client.h"
],
defines
=
select
({
"//conditions:default"
:
[],
"enable_edgetpu"
:
[
"ENABLE_EDGETPU"
],
}),
deps
=
[
":mobile_ssd_client"
,
":mobile_ssd_tflite_client"
,
"@com_google_glog//:glog"
,
"@com_google_absl//absl/base:core_headers"
,
"@com_google_glog//:glog"
,
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops"
,
]
+
select
({
"//conditions:default"
:
[],
"enable_edgetpu"
:
[
"@libedgetpu//libedgetpu:header"
],
}),
],
alwayslink
=
1
,
)
research/lstm_object_detection/tflite/WORKSPACE
View file @
4d09de12
...
...
@@ -90,6 +90,12 @@ http_archive(
sha256
=
"79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc"
,
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
http_archive
(
...
...
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
View file @
4d09de12
...
...
@@ -66,11 +66,6 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
interpreter_
->
UseNNAPI
(
false
);
}
#ifdef ENABLE_EDGETPU
interpreter_
->
SetExternalContext
(
kTfLiteEdgeTpuContext
,
edge_tpu_context_
.
get
());
#endif
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
if
(
interpreter_
->
inputs
().
size
()
!=
3
)
{
...
...
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
View file @
4d09de12
...
...
@@ -76,10 +76,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
std
::
unique_ptr
<::
tflite
::
MutableOpResolver
>
resolver_
;
std
::
unique_ptr
<::
tflite
::
Interpreter
>
interpreter_
;
#ifdef ENABLE_EDGETPU
std
::
unique_ptr
<
edgetpu
::
EdgeTpuContext
>
edge_tpu_context_
;
#endif
private:
// MobileSSDTfLiteClient is neither copyable nor movable.
MobileSSDTfLiteClient
(
const
MobileSSDTfLiteClient
&
)
=
delete
;
...
...
@@ -107,6 +103,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
bool
FloatInference
(
const
uint8_t
*
input_data
);
bool
QuantizedInference
(
const
uint8_t
*
input_data
);
void
GetOutputBoxesAndScoreTensorsFromUInt8
();
#ifdef ENABLE_EDGETPU
std
::
unique_ptr
<
edgetpu
::
EdgeTpuContext
>
edge_tpu_context_
;
#endif
};
}
// namespace tflite
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment