Commit 7c732da7 authored by Nimit Nigania's avatar Nimit Nigania
Browse files

Merge remote-tracking branch 'upstream/master'

parents cb8ce606 e36934b3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from absl import flags
import tensorflow as tf
from lstm_object_detection.utils import config_util
flags.DEFINE_string('export_path', None, 'Path to export model.')
flags.DEFINE_string('frozen_graph_path', None, 'Path to frozen graph.')
flags.DEFINE_string(
'pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
FLAGS = flags.FLAGS
def main(_):
flags.mark_flag_as_required('export_path')
flags.mark_flag_as_required('frozen_graph_path')
flags.mark_flag_as_required('pipeline_config_path')
configs = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
lstm_config = configs['lstm_model']
input_arrays = ['input_video_tensor']
output_arrays = [
'TFLite_Detection_PostProcess',
'TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2',
'TFLite_Detection_PostProcess:3',
]
input_shapes = {
'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
}
converter = tf.lite.TFLiteConverter.from_frozen_graph(
FLAGS.frozen_graph_path, input_arrays, output_arrays,
input_shapes=input_shapes
)
converter.allow_custom_ops = True
tflite_model = converter.convert()
ofilename = os.path.join(FLAGS.export_path)
open(ofilename, "wb").write(tflite_model)
if __name__ == '__main__':
tf.app.run()
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires 2 steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
With a candidate checkpoint to export, run the following command from
tensorflow/models/research:
```bash
# from tensorflow/models/research
PIPELINE_CONFIG_PATH={path to pipeline config}
TRAINED_CKPT_PREFIX=/{path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path ${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
--output_directory ${EXPORT_DIR} \
--add_preprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
We then take the exported tflite-compatable tflite model, and convert it to a
TFLite FlatBuffer file by running the following:
```bash
# from tensorflow/models/research
FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
EXPORT_PATH={path to filename that will be used for export}
PIPELINE_CONFIG_PATH={path to pipeline config}
python lstm_object_detection/export_tflite_lstd_model.py \
--export_path ${EXPORT_PATH} \
--frozen_graph_path ${FROZEN_GRAPH_PATH} \
--pipeline_config_path ${PIPELINE_CONFIG_PATH}
```
After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
model to be used by an application.
\ No newline at end of file
......@@ -33,68 +33,6 @@ from object_detection.protos import preprocessor_pb2
class DatasetBuilderTest(tf.test.TestCase):
def _create_tf_record(self):
path = os.path.join(self.get_temp_dir(), 'tfrecord')
writer = tf.python_io.TFRecordWriter(path)
image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
with self.test_session():
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
sequence_example = example_pb2.SequenceExample(
context=feature_pb2.Features(
feature={
'image/format':
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=['jpeg'.encode('utf-8')])),
'image/height':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
'image/width':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
}),
feature_lists=feature_pb2.FeatureLists(
feature_list={
'image/encoded':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=[encoded_jpeg])),
]),
'image/object/bbox/xmin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/xmax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/bbox/ymin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/ymax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/class/label':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[2]))
]),
}))
writer.write(sequence_example.SerializeToString())
writer.close()
return path
def _get_model_configs_from_proto(self):
"""Creates a model text proto for testing.
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl import flags
import numpy as np
import tensorflow as tf
flags.DEFINE_string('model_path', None, 'Path to model.')
FLAGS = flags.FLAGS
def main(_):
flags.mark_flag_as_required('model_path')
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
print 'input_details:', input_details
output_details = interpreter.get_output_details()
print 'output_details:', output_details
# Test model on random input data.
input_shape = input_details[0]['shape']
# change the following line to feed into your own data.
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print output_data
if __name__ == '__main__':
tf.app.run()
......@@ -37,7 +37,7 @@ def get_configs_from_pipeline_file(pipeline_config_path):
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_confg`.
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Value are the corresponding config objects.
"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment